1#!/usr/bin/env python3
2# SPDX-License-Identifier: Apache-2.0
3# Copyright (c) 2021 Intel Corporation
4
5# A script to generate twister options based on modified files.
6
7import re, os
8import argparse
9import yaml
10import fnmatch
11import subprocess
12import json
13import logging
14import sys
15from pathlib import Path
16from git import Repo
17
# Fail fast: everything below resolves paths relative to the Zephyr tree.
if "ZEPHYR_BASE" not in os.environ:
    exit("$ZEPHYR_BASE environment variable undefined.")

repository_path = Path(os.environ['ZEPHYR_BASE'])
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)

# Make the in-tree helper module 'list_boards' (under scripts/) importable.
# This must run before the import on the next line.
sys.path.append(os.path.join(repository_path, 'scripts'))
import list_boards
26
27def _get_match_fn(globs, regexes):
28    # Constructs a single regex that tests for matches against the globs in
29    # 'globs' and the regexes in 'regexes'. Parts are joined with '|' (OR).
30    # Returns the search() method of the compiled regex.
31    #
32    # Returns None if there are neither globs nor regexes, which should be
33    # interpreted as no match.
34
35    if not (globs or regexes):
36        return None
37
38    regex = ""
39
40    if globs:
41        glob_regexes = []
42        for glob in globs:
43            # Construct a regex equivalent to the glob
44            glob_regex = glob.replace(".", "\\.").replace("*", "[^/]*") \
45                             .replace("?", "[^/]")
46
47            if not glob.endswith("/"):
48                # Require a full match for globs that don't end in /
49                glob_regex += "$"
50
51            glob_regexes.append(glob_regex)
52
53        # The glob regexes must anchor to the beginning of the path, since we
54        # return search(). (?:) is a non-capturing group.
55        regex += "^(?:{})".format("|".join(glob_regexes))
56
57    if regexes:
58        if regex:
59            regex += "|"
60        regex += "|".join(regexes)
61
62    return re.compile(regex).search
63
class Tag:
    """
    Represents an entry for a tag in tags.yaml.

    These attributes are available:

    name:
        The tag's name (its key in tags.yaml).

    description:
        Text from the 'description' key, or None if the entry has no
        'description' key. NOTE(review): the find_tags() code in this
        script does not actually set this attribute; documented only as
        part of the tags.yaml schema.

    _match_fn:
        search() method testing a path against 'files'/'files-regex',
        or None when the tag lists neither.

    _exclude_match_fn:
        Like _match_fn, but for 'files-exclude'/'files-regex-exclude'.
    """
    def _contains(self, path):
        # Returns True if the tag covers 'path': it matches the include
        # patterns and does not match the exclude patterns. bool()
        # normalizes the raw re.Match/None result into a proper bool.
        return bool(self._match_fn and self._match_fn(path) and not
                    (self._exclude_match_fn and self._exclude_match_fn(path)))

    def __repr__(self):
        return "<Tag {}>".format(self.name)
86
class Filters:
    """Computes a twister test plan from a list of modified files.

    The individual find_* methods inspect self.modified_files and, when
    they find something relevant, invoke twister (via get_plan()) with the
    corresponding filter options; the discovered test suites accumulate in
    self.all_tests.
    """

    def __init__(self, modified_files, pull_request=False, platforms=None):
        # Paths of changed files, relative to ZEPHYR_BASE.
        self.modified_files = modified_files
        # Kept for interface compatibility (not used in this class).
        self.twister_options = []
        # Set to True by find_excludes() when a full run is needed.
        self.full_twister = False
        # Accumulated test suite dicts from all get_plan() calls.
        self.all_tests = []
        # '-e <tag>' options computed by find_tags().
        self.tag_options = []
        self.pull_request = pull_request
        # Normalize None/empty to [] instead of using a mutable default
        # argument (a shared-default bug waiting to happen).
        self.platforms = platforms if platforms else []
        # Set to True by find_* methods on global-looking changes.
        self.default_run = False

    def process(self):
        """Run all the filters, then compute exclusions."""
        self.find_tags()
        self.find_tests()
        if not self.platforms:
            self.find_archs()
            self.find_boards()

        if self.default_run:
            # A global-looking change: tests/boards were already deemed
            # too broad to enumerate, so don't re-match those patterns.
            self.find_excludes(skip=["tests/*", "boards/*/*/*"])
        else:
            self.find_excludes()

    def get_plan(self, options, integration=False):
        """Invoke twister in collection-only mode ('-c') with 'options' and
        fold the resulting test suites into self.all_tests.

        The intermediate JSON file is always cleaned up, even if parsing
        raises.
        """
        fname = "_test_plan_partial.json"
        cmd = ["scripts/twister", "-c"] + options + ["--save-tests", fname]
        if integration:
            cmd.append("--integration")

        logging.info(" ".join(cmd))
        _ = subprocess.call(cmd)
        try:
            with open(fname, newline='') as jsonfile:
                json_data = json.load(jsonfile)
                suites = json_data.get("testsuites", [])
                self.all_tests.extend(suites)
        finally:
            # Remove the temp file even when open()/json.load() fails.
            if os.path.exists(fname):
                os.remove(fname)

    def find_archs(self):
        """Derive '-a <arch>' filters from changes under arch/<arch>/ or
        include/arch/<arch>/.

        'common' is skipped. Some architectures like riscv require special
        handling: the riscv directory covers the two architectures known to
        twister, riscv32 and riscv64.
        """
        archs = set()

        for f in self.modified_files:
            p = re.match(r"^arch\/([^/]+)\/", f)
            if not p:
                p = re.match(r"^include\/arch\/([^/]+)\/", f)
            if p:
                if p.group(1) != 'common':
                    if p.group(1) == 'riscv':
                        archs.add('riscv32')
                        archs.add('riscv64')
                    else:
                        archs.add(p.group(1))

        _options = []
        for arch in archs:
            _options.extend(["-a", arch])

        if _options:
            logging.info('Potential architecture filters...')
            if self.platforms:
                for platform in self.platforms:
                    _options.extend(["-p", platform])

                self.get_plan(_options, True)
            else:
                self.get_plan(_options, False)

    def find_boards(self):
        """Derive '-p <board>' filters from changes under boards/, matching
        changed board directory names against the known board list."""
        boards = set()
        all_boards = set()

        for f in self.modified_files:
            # Documentation/image-only changes don't affect test runs.
            if f.endswith(".rst") or f.endswith(".png") or f.endswith(".jpg"):
                continue
            p = re.match(r"^boards\/[^/]+\/([^/]+)\/", f)
            if p and p.groups():
                boards.add(p.group(1))

        # Limit search to $ZEPHYR_BASE since this is where the changed files are
        lb_args = argparse.Namespace(**{ 'arch_roots': [repository_path], 'board_roots': [repository_path] })
        known_boards = list_boards.find_boards(lb_args)
        for b in boards:
            # The directory name is treated as a regex so one directory can
            # map to several board variants (e.g. revisions).
            name_re = re.compile(b)
            for kb in known_boards:
                if name_re.search(kb.name):
                    all_boards.add(kb.name)

        _options = []
        if len(all_boards) > 20:
            logging.warning(f"{len(boards)} boards changed, this looks like a global change, skipping test handling, revert to default.")
            self.default_run = True
            return

        for board in all_boards:
            _options.extend(["-p", board])

        if _options:
            logging.info('Potential board filters...')
            self.get_plan(_options)

    def find_tests(self):
        """Derive '-T <dir>' filters from changed files that live inside a
        test or sample directory (one containing testcase.yaml or
        sample.yaml, searched upward from the file)."""
        tests = set()
        for f in self.modified_files:
            if f.endswith(".rst"):
                continue
            d = os.path.dirname(f)
            # Walk up the path until a test/sample manifest is found.
            while d:
                if os.path.exists(os.path.join(d, "testcase.yaml")) or \
                    os.path.exists(os.path.join(d, "sample.yaml")):
                    tests.add(d)
                    break
                else:
                    d = os.path.dirname(d)

        _options = []
        for t in tests:
            _options.extend(["-T", t])

        if len(tests) > 20:
            logging.warning(f"{len(tests)} tests changed, this looks like a global change, skipping test handling, revert to default")
            self.default_run = True
            return

        if _options:
            logging.info(f'Potential test filters...({len(tests)} changed...)')
            if self.platforms:
                for platform in self.platforms:
                    _options.extend(["-p", platform])
            else:
                _options.append("--all")
            self.get_plan(_options)

    def find_tags(self):
        """Compute '-e <tag>' exclusions for every tag in tags.yaml whose
        file patterns match none of the modified files."""
        tag_cfg_file = os.path.join(repository_path, 'scripts', 'ci', 'tags.yaml')
        with open(tag_cfg_file, 'r') as ymlfile:
            tags_config = yaml.safe_load(ymlfile)

        tags = {}
        for t, x in tags_config.items():
            tag = Tag()
            # Assume excluded until a modified file proves otherwise.
            tag.exclude = True
            tag.name = t

            # tag._match_fn(path) tests if the path matches files and/or
            # files-regex
            tag._match_fn = _get_match_fn(x.get("files"), x.get("files-regex"))

            # Like tag._match_fn(path), but for files-exclude and
            # files-regex-exclude
            tag._exclude_match_fn = \
                _get_match_fn(x.get("files-exclude"), x.get("files-regex-exclude"))

            tags[tag.name] = tag

        for f in self.modified_files:
            for t in tags.values():
                if t._contains(f):
                    t.exclude = False

        exclude_tags = set()
        for t in tags.values():
            if t.exclude:
                exclude_tags.add(t.name)

        for tag in exclude_tags:
            self.tag_options.extend(["-e", tag])

        if exclude_tags:
            logging.info(f'Potential tag based filters: {exclude_tags}')

    def find_excludes(self, skip=None):
        """Decide whether a full/partial twister run is needed: if any
        modified file is NOT covered by the ignore patterns in
        twister_ignore.txt, schedule a run (with the tag exclusions
        computed by find_tags()).

        skip: ignore patterns to disregard (avoids a mutable default arg).
        """
        skip = skip or []
        with open("scripts/ci/twister_ignore.txt", "r") as twister_ignore:
            ignores = twister_ignore.read().splitlines()
            ignores = filter(lambda x: not x.startswith("#"), ignores)

        found = set()
        # Drop empty entries (e.g. from trailing newlines in a diff).
        files = list(filter(lambda x: x, self.modified_files))

        for pattern in ignores:
            if pattern in skip:
                continue
            if pattern:
                found.update(fnmatch.filter(files, pattern))

        logging.debug(found)
        logging.debug(files)

        if sorted(files) != sorted(found):
            _options = []
            logging.info('Need to run full or partial twister...')
            self.full_twister = True
            if self.platforms:
                for platform in self.platforms:
                    _options.extend(["-p", platform])

                _options.extend(self.tag_options)
                self.get_plan(_options)
            else:
                _options.extend(self.tag_options)
                self.get_plan(_options, True)
        else:
            logging.info('No twister needed or partial twister run only...')
293
def parse_args():
    """Parse and return the command line arguments."""
    parser = argparse.ArgumentParser(
                description="Generate twister argument files based on modified file",
                allow_abbrev=False)
    parser.add_argument('-c', '--commits', default=None,
            help="Commit range in the form: a..b")
    parser.add_argument('-m', '--modified-files', default=None,
            help="File with information about changed/deleted/added files.")
    parser.add_argument('-o', '--output-file', default="testplan.json",
            help="JSON file with the test plan to be passed to twister")
    parser.add_argument('-P', '--pull-request', action="store_true",
            help="This is a pull request")
    parser.add_argument('-p', '--platform', action="append",
            help="Limit this for a platform or a list of platforms.")
    parser.add_argument('-t', '--tests_per_builder', default=700, type=int,
            help="Number of tests per builder")
    # Fixed: help text was a copy/paste of --tests_per_builder's.
    parser.add_argument('-n', '--default-matrix', default=10, type=int,
            help="Default number of matrix builders")

    return parser.parse_args()
314
315
if __name__ == "__main__":

    args = parse_args()
    files = []
    errors = 0

    # Collect the list of changed files: either from a git commit range or
    # from a JSON file that lists them.
    if args.commits:
        repo = Repo(repository_path)
        commit = repo.git.diff("--name-only", args.commits)
        files = commit.split("\n")
    elif args.modified_files:
        with open(args.modified_files, "r") as fp:
            files = json.load(fp)

    if files:
        print("Changed files:\n=========")
        print("\n".join(files))
        print("=========")

    f = Filters(files, args.pull_request, args.platform)
    f.process()

    # Remove duplicates (same name/arch/platform) and filtered cases,
    # counting errored suites along the way.
    dup_free = []
    dup_free_set = set()
    logging.info(f'Total tests gathered: {len(f.all_tests)}')
    for ts in f.all_tests:
        if ts.get('status') == 'filtered':
            continue
        n = ts.get("name")
        a = ts.get("arch")
        p = ts.get("platform")
        if ts.get('status') == 'error':
            logging.info(f"Error found: {n} on {p} ({ts.get('reason')})")
            errors += 1
        if (n, a, p,) not in dup_free_set:
            dup_free.append(ts)
            dup_free_set.add((n, a, p,))

    logging.info(f'Total tests to be run: {len(dup_free)}')

    # Write the CI environment file with the builder-matrix sizing.
    with open(".testplan", "w") as tp:
        total_tests = len(dup_free)
        if total_tests and total_tests < args.tests_per_builder:
            nodes = 1
        else:
            # NOTE(review): round() can under-allocate, e.g. 1049 tests at
            # 700 per builder rounds to 1 node; math.ceil may be intended —
            # confirm before changing CI sizing.
            nodes = round(total_tests / args.tests_per_builder)

        tp.write(f"TWISTER_TESTS={total_tests}\n")
        tp.write(f"TWISTER_NODES={nodes}\n")
        tp.write(f"TWISTER_FULL={f.full_twister}\n")
        logging.info(f'Total nodes to launch: {nodes}')

    # (An unused 'header' column list was removed here as dead code.)

    # Write the final test plan consumed by twister.
    if dup_free:
        data = {}
        data['testsuites'] = dup_free
        with open(args.output_file, 'w', newline='') as json_file:
            json.dump(data, json_file, indent=4, separators=(',',':'))

    # Exit status is the number of test suites that errored out.
    sys.exit(errors)
378