1#!/usr/bin/env python3
2# SPDX-License-Identifier: Apache-2.0
3# Copyright (c) 2021 Intel Corporation
4
5# A script to generate twister options based on modified files.
6
7import re, os
8import argparse
9import glob
10import yaml
11import json
12import fnmatch
13import subprocess
14import csv
15import logging
16from git import Repo
17
# This script must run from (and against) a Zephyr tree; bail out early
# if the environment does not point at one.
if "ZEPHYR_BASE" not in os.environ:
    exit("$ZEPHYR_BASE environment variable undefined.")

# Root of the Zephyr tree; repository-relative paths (e.g. scripts/ci/tags.yaml)
# are resolved against it.
repository_path = os.environ['ZEPHYR_BASE']
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)
23
24def _get_match_fn(globs, regexes):
25    # Constructs a single regex that tests for matches against the globs in
26    # 'globs' and the regexes in 'regexes'. Parts are joined with '|' (OR).
27    # Returns the search() method of the compiled regex.
28    #
29    # Returns None if there are neither globs nor regexes, which should be
30    # interpreted as no match.
31
32    if not (globs or regexes):
33        return None
34
35    regex = ""
36
37    if globs:
38        glob_regexes = []
39        for glob in globs:
40            # Construct a regex equivalent to the glob
41            glob_regex = glob.replace(".", "\\.").replace("*", "[^/]*") \
42                             .replace("?", "[^/]")
43
44            if not glob.endswith("/"):
45                # Require a full match for globs that don't end in /
46                glob_regex += "$"
47
48            glob_regexes.append(glob_regex)
49
50        # The glob regexes must anchor to the beginning of the path, since we
51        # return search(). (?:) is a non-capturing group.
52        regex += "^(?:{})".format("|".join(glob_regexes))
53
54    if regexes:
55        if regex:
56            regex += "|"
57        regex += "|".join(regexes)
58
59    return re.compile(regex).search
60
class Tag:
    """
    Represents one entry from tags.yaml.

    Attributes (assigned externally by the tags.yaml loader):

    name:
        The tag name (the key in tags.yaml).

    exclude:
        Flag used by the caller; set True initially and cleared when a
        modified file is found to belong to this tag.

    _match_fn / _exclude_match_fn:
        search() callables built from the include/exclude patterns of the
        tag, or None when the tag has no such patterns.
    """
    def _contains(self, path):
        # A path belongs to this tag when it matches the include patterns
        # and is not ruled out by the exclude patterns.
        matched = self._match_fn and self._match_fn(path)
        excluded = self._exclude_match_fn and self._exclude_match_fn(path)
        return matched and not excluded

    def __repr__(self):
        return f"<Tag {self.name}>"
83
class Filters:
    """Derives twister invocations from a list of modified files and
    collects the resulting partial test plans into self.all_tests."""

    def __init__(self, modified_files, pull_request=False, platforms=None):
        """
        modified_files: list of repository-relative paths that changed.
        pull_request: True when generating a plan for a pull request.
        platforms: optional list of platform names to limit the run to.
        """
        self.modified_files = modified_files
        self.twister_options = []
        self.full_twister = False
        # Rows accumulated from every partial twister test plan CSV.
        self.all_tests = []
        self.tag_options = []
        self.pull_request = pull_request
        # BUGFIX: the default used to be a mutable [] shared across all
        # instances; normalize None (e.g. argparse's default) to a fresh list.
        self.platforms = platforms if platforms is not None else []

    def process(self):
        """Run every filter against the modified files."""
        self.find_tags()
        self.find_excludes()
        self.find_tests()
        if not self.platforms:
            self.find_archs()
        self.find_boards()

    def get_plan(self, options, integration=False):
        """Invoke twister with 'options' to produce a partial test plan
        and append its rows (minus the CSV header) to self.all_tests."""
        fname = "_test_plan_partial.csv"
        cmd = ["scripts/twister", "-c"] + options + ["--save-tests", fname]
        if integration:
            cmd.append("--integration")

        logging.info(" ".join(cmd))
        ret = subprocess.call(cmd)
        if ret != 0:
            # Twister exits non-zero on some failures; the plan file may
            # still have been written, so only warn here.
            logging.warning("twister exited with code %d", ret)
        # BUGFIX: previously the file was opened unconditionally and a
        # twister crash before writing it raised FileNotFoundError.
        if os.path.exists(fname):
            with open(fname, newline='') as csvfile:
                csv_reader = csv.reader(csvfile, delimiter=',')
                # Skip the header row; tolerate a completely empty file.
                if next(csv_reader, None) is not None:
                    for row in csv_reader:
                        self.all_tests.append(row)
            os.remove(fname)

    def find_archs(self):
        """Filter by architecture for changes under arch/ or include/arch/.

        We match both arch/<arch>/* and include/arch/<arch> and skip common.
        Some architectures like riscv require special handling, i.e. the
        riscv directory covers 2 architectures known to twister: riscv32
        and riscv64.
        """
        archs = set()

        for f in self.modified_files:
            p = re.match(r"^arch\/([^/]+)\/", f)
            if not p:
                p = re.match(r"^include\/arch\/([^/]+)\/", f)
            if p:
                if p.group(1) != 'common':
                    if p.group(1) == 'riscv':
                        archs.add('riscv32')
                        archs.add('riscv64')
                    else:
                        archs.add(p.group(1))

        _options = []
        for arch in archs:
            _options.extend(["-a", arch])

        if _options:
            logging.info('Potential architecture filters...')
            if self.platforms:
                for platform in self.platforms:
                    _options.extend(["-p", platform])

                # Platform-limited runs use twister's --integration set.
                self.get_plan(_options, True)
            else:
                self.get_plan(_options, False)

    def find_boards(self):
        """Filter by board for changes under boards/<arch>/<board>/."""
        boards = set()
        all_boards = set()

        for f in self.modified_files:
            # Documentation/image changes do not affect board builds.
            if f.endswith(".rst") or f.endswith(".png") or f.endswith(".jpg"):
                continue
            p = re.match(r"^boards\/[^/]+\/([^/]+)\/", f)
            if p and p.groups():
                boards.add(p.group(1))

        # Expand each board directory to the concrete board names declared
        # by its *.yaml files (one board dir can define several variants).
        for b in boards:
            suboards = glob.glob("boards/*/%s/*.yaml" % (b))
            for subboard in suboards:
                name = os.path.splitext(os.path.basename(subboard))[0]
                if name:
                    all_boards.add(name)

        _options = []
        for board in all_boards:
            _options.extend(["-p", board])

        if _options:
            logging.info('Potential board filters...')
            self.get_plan(_options)

    def find_tests(self):
        """Filter by test/sample for changes inside a test or sample tree."""
        tests = set()
        for f in self.modified_files:
            if f.endswith(".rst"):
                continue
            # Walk up from the modified file until a directory containing a
            # testcase.yaml or sample.yaml is found.
            d = os.path.dirname(f)
            while d:
                if os.path.exists(os.path.join(d, "testcase.yaml")) or \
                    os.path.exists(os.path.join(d, "sample.yaml")):
                    tests.add(d)
                    break
                else:
                    d = os.path.dirname(d)

        _options = []
        for t in tests:
            _options.extend(["-T", t])

        if _options:
            logging.info('Potential test filters...')
            if self.platforms:
                for platform in self.platforms:
                    _options.extend(["-p", platform])
            else:
                _options.append("--all")
            self.get_plan(_options)

    def find_tags(self):
        """Compute '-e <tag>' exclusions for tags whose file patterns match
        none of the modified files (see scripts/ci/tags.yaml)."""
        tag_cfg_file = os.path.join(repository_path, 'scripts', 'ci', 'tags.yaml')
        with open(tag_cfg_file, 'r') as ymlfile:
            tags_config = yaml.safe_load(ymlfile)

        tags = {}
        for t, x in tags_config.items():
            tag = Tag()
            # Assume excluded until a modified file hits the tag's patterns.
            tag.exclude = True
            tag.name = t

            # tag._match_fn(path) tests if the path matches files and/or
            # files-regex
            tag._match_fn = _get_match_fn(x.get("files"), x.get("files-regex"))

            # Like tag._match_fn(path), but for files-exclude and
            # files-regex-exclude
            tag._exclude_match_fn = \
                _get_match_fn(x.get("files-exclude"), x.get("files-regex-exclude"))

            tags[tag.name] = tag

        for f in self.modified_files:
            for t in tags.values():
                if t._contains(f):
                    t.exclude = False

        exclude_tags = {t.name for t in tags.values() if t.exclude}

        for tag in exclude_tags:
            self.tag_options.extend(["-e", tag])

        if exclude_tags:
            logging.info('Potential tag based filters...')

    def find_excludes(self):
        """Decide whether a full/partial twister run is needed: if any
        modified file is NOT covered by twister_ignore.txt, run twister
        with the tag-based exclusions computed by find_tags()."""
        # NOTE(review): this path is cwd-relative while find_tags() resolves
        # against repository_path -- the script assumes cwd == ZEPHYR_BASE.
        with open("scripts/ci/twister_ignore.txt", "r") as twister_ignore:
            ignores = twister_ignore.read().splitlines()
            ignores = [x for x in ignores if not x.startswith("#")]

        found = set()
        files = list(filter(lambda x: x, self.modified_files))

        for pattern in ignores:
            if pattern:
                found.update(fnmatch.filter(files, pattern))

        logging.debug(found)
        logging.debug(files)

        if sorted(files) != sorted(found):
            _options = []
            logging.info('Need to run full or partial twister...')
            self.full_twister = True
            if self.platforms:
                for platform in self.platforms:
                    _options.extend(["-p", platform])

                _options.extend(self.tag_options)
                self.get_plan(_options)
            else:
                _options.extend(self.tag_options)
                self.get_plan(_options, True)
        else:
            logging.info('No twister needed or partial twister run only...')
273
def parse_args():
    """Parse and return this script's command line arguments."""
    parser = argparse.ArgumentParser(
                description="Generate twister argument files based on modified file")
    parser.add_argument('-c', '--commits', default=None,
            help="Commit range in the form: a..b")
    parser.add_argument('-m', '--modified-files', default=None,
            help="File with information about changed/deleted/added files.")
    parser.add_argument('-o', '--output-file', default="testplan.csv",
            help="CSV file with the test plan to be passed to twister")
    parser.add_argument('-P', '--pull-request', action="store_true",
            help="This is a pull request")
    parser.add_argument('-p', '--platform', action="append",
            help="Limit this for a platform or a list of platforms.")
    parser.add_argument('-t', '--tests_per_builder', default=700, type=int,
            help="Number of tests per builder")
    # BUGFIX: the help text below was a copy-paste duplicate of -t's.
    parser.add_argument('-n', '--default-matrix', default=10, type=int,
            help="Number of builders to use when the computed matrix "
                 "exceeds the maximum")

    return parser.parse_args()
293
294
if __name__ == "__main__":

    args = parse_args()
    # Collect the list of changed files either from a git commit range or
    # from a JSON file supplied by the caller.
    if args.commits:
        repo = Repo(repository_path)
        commit = repo.git.diff("--name-only", args.commits)
        files = commit.split("\n")
    elif args.modified_files:
        with open(args.modified_files, "r") as fp:
            files = json.load(fp)
    else:
        # BUGFIX: without this branch 'files' was undefined and the script
        # crashed with a NameError a few lines below.
        exit("Please specify either --commits or --modified-files.")

    print("Changed files:\n=========")
    print("\n".join(files))
    print("=========")

    f = Filters(files, args.pull_request, args.platform)
    f.process()

    # Remove duplicates and filtered cases; column 3 of each row is the
    # twister status ('skipped' rows are dropped).
    dup_free = []
    dup_free_set = set()
    for x in f.all_tests:
        if x[3] == 'skipped':
            continue
        if tuple(x) not in dup_free_set:
            dup_free.append(x)
            dup_free_set.add(tuple(x))

    logging.info(f'Total tests to be run: {len(dup_free)}')
    # Emit builder-matrix sizing for CI as simple KEY=VALUE lines.
    with open(".testplan", "w") as tp:
        total_tests = len(dup_free)
        nodes = round(total_tests / args.tests_per_builder)
        # The condition below holds whenever total_tests >= tests_per_builder,
        # i.e. at least one builder is full, so add one extra node.
        if total_tests % args.tests_per_builder != total_tests:
            nodes = nodes + 1

        # Cap large matrices at the configured default size.
        if nodes > 5:
            nodes = args.default_matrix

        tp.write(f"TWISTER_TESTS={total_tests}\n")
        tp.write(f"TWISTER_NODES={nodes}\n")

    header = ['test', 'arch', 'platform', 'status', 'extra_args', 'handler',
            'handler_time', 'ram_size', 'rom_size']

    # Write the final test plan CSV only when there is something to run.
    if dup_free:
        with open(args.output_file, 'w', newline='') as csv_file:
            writer = csv.writer(csv_file)
            writer.writerow(header)
            writer.writerows(dup_free)
345