1#!/usr/bin/env python3
2# SPDX-License-Identifier: Apache-2.0
3# Copyright (c) 2021 Intel Corporation
4
5# A script to generate twister options based on modified files.
6
7import re, os
8import argparse
9import yaml
10import fnmatch
11import subprocess
12import json
13import logging
14import sys
15from pathlib import Path
16from git import Repo
17from west.manifest import Manifest
18
# Abort early: everything below resolves paths relative to the Zephyr tree.
if "ZEPHYR_BASE" not in os.environ:
    exit("$ZEPHYR_BASE environment variable undefined.")

# Root of the Zephyr repository; used for board/tag/config lookups below.
repository_path = Path(os.environ['ZEPHYR_BASE'])
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)

# Make the in-tree helper module scripts/list_boards.py importable; the
# import must come after the sys.path manipulation.
sys.path.append(os.path.join(repository_path, 'scripts'))
import list_boards
27
28def _get_match_fn(globs, regexes):
29    # Constructs a single regex that tests for matches against the globs in
30    # 'globs' and the regexes in 'regexes'. Parts are joined with '|' (OR).
31    # Returns the search() method of the compiled regex.
32    #
33    # Returns None if there are neither globs nor regexes, which should be
34    # interpreted as no match.
35
36    if not (globs or regexes):
37        return None
38
39    regex = ""
40
41    if globs:
42        glob_regexes = []
43        for glob in globs:
44            # Construct a regex equivalent to the glob
45            glob_regex = glob.replace(".", "\\.").replace("*", "[^/]*") \
46                             .replace("?", "[^/]")
47
48            if not glob.endswith("/"):
49                # Require a full match for globs that don't end in /
50                glob_regex += "$"
51
52            glob_regexes.append(glob_regex)
53
54        # The glob regexes must anchor to the beginning of the path, since we
55        # return search(). (?:) is a non-capturing group.
56        regex += "^(?:{})".format("|".join(glob_regexes))
57
58    if regexes:
59        if regex:
60            regex += "|"
61        regex += "|".join(regexes)
62
63    return re.compile(regex).search
64
class Tag:
    """
    Represents one entry from scripts/ci/tags.yaml.

    Attributes are assigned externally (see Filters.find_tags()):

    name:
        The tag's key in tags.yaml.

    exclude:
        Scratch flag; starts True and is cleared when a modified file falls
        inside this tag's file patterns.

    _match_fn / _exclude_match_fn:
        regex search() callables (or None) built from the files/files-regex
        and files-exclude/files-regex-exclude keys respectively.
    """
    def _contains(self, path):
        # Truthy when 'path' is matched by the include patterns and not
        # ruled out by the exclude patterns.
        included = self._match_fn and self._match_fn(path)
        excluded = self._exclude_match_fn and self._exclude_match_fn(path)
        return included and not excluded

    def __repr__(self):
        return f"<Tag {self.name}>"
87
class Filters:
    """Derives twister invocations from the list of modified files.

    Each find_* pass inspects self.modified_files, builds a set of twister
    command-line options, and calls get_plan() to accumulate the resulting
    test suites into self.all_tests.
    """
    def __init__(self, modified_files, pull_request=False, platforms=None):
        """
        Args:
            modified_files: list of changed file paths, relative to
                ZEPHYR_BASE.
            pull_request: True when running in a pull-request context.
            platforms: optional list of platform names to restrict runs to.
                Fixed: previously a mutable default argument ([]), which is
                shared across calls; use a None sentinel instead.
        """
        self.modified_files = modified_files
        self.twister_options = []
        self.full_twister = False
        self.all_tests = []
        self.tag_options = []
        self.pull_request = pull_request
        # Normalize None (e.g. argparse's default for --platform) to [].
        self.platforms = platforms if platforms else []
        self.default_run = False

    def process(self):
        """Run all filter passes in order; arch/board passes only apply when
        no explicit platform list was given."""
        self.find_modules()
        self.find_tags()
        self.find_tests()
        if not self.platforms:
            self.find_archs()
            self.find_boards()

        if self.default_run:
            # A global-looking change was detected: fall back to a default
            # run, but skip the ignore patterns already covered above.
            self.find_excludes(skip=["tests/*", "boards/*/*/*"])
        else:
            self.find_excludes()

    def get_plan(self, options, integration=False):
        """Invoke twister with 'options' to produce a partial test plan and
        append the resulting suites to self.all_tests.

        Args:
            options: extra twister command-line arguments.
            integration: when True, pass --integration to twister.
        """
        fname = "_test_plan_partial.json"
        cmd = ["scripts/twister", "-c"] + options + ["--save-tests", fname]
        if integration:
            cmd.append("--integration")

        logging.info(" ".join(cmd))
        _ = subprocess.call(cmd)
        with open(fname, newline='') as jsonfile:
            json_data = json.load(jsonfile)
            suites = json_data.get("testsuites", [])
            self.all_tests.extend(suites)
        if os.path.exists(fname):
            os.remove(fname)

    def find_modules(self):
        """If west.yml changed, run suites tagged with the names of all
        added, removed or updated manifest projects."""
        if 'west.yml' in self.modified_files:
            print("Manifest file 'west.yml' changed")
            print("=========")
            # NOTE(review): assumes args.commits is a range ending in '..'
            # (e.g. "base_sha.."), so [:-2] yields the old revision; confirm
            # against how CI passes --commits.
            old_manifest_content = repo.git.show(f"{args.commits[:-2]}:west.yml")
            with open("west_old.yml", "w") as manifest:
                manifest.write(old_manifest_content)
            old_manifest = Manifest.from_file("west_old.yml")
            new_manifest = Manifest.from_file("west.yml")
            # Fix: remove the temporary manifest copy instead of leaving it
            # behind in the working tree.
            os.remove("west_old.yml")
            old_projs = set((p.name, p.revision) for p in old_manifest.projects)
            new_projs = set((p.name, p.revision) for p in new_manifest.projects)
            logging.debug(f'old_projs: {old_projs}')
            logging.debug(f'new_projs: {new_projs}')
            # Removed projects: in old but their name is gone from new.
            rprojs = set(filter(lambda p: p[0] not in list(p[0] for p in new_projs),
                old_projs - new_projs))
            # Updated projects: name exists in both, revision differs.
            uprojs = set(filter(lambda p: p[0] in list(p[0] for p in old_projs),
                new_projs - old_projs))
            # Added projects: in new only.
            aprojs = new_projs - old_projs - uprojs

            # All affected projects.
            projs = rprojs | uprojs | aprojs
            projs_names = [name for name, rev in projs]

            logging.info(f'rprojs: {rprojs}')
            logging.info(f'uprojs: {uprojs}')
            logging.info(f'aprojs: {aprojs}')
            logging.info(f'project: {projs_names}')

            _options = []
            for p in projs_names:
                _options.extend(["-t", p])

            if self.platforms:
                for platform in self.platforms:
                    _options.extend(["-p", platform])

            self.get_plan(_options, True)

    def find_archs(self):
        """Select architectures touched by the change.

        We match both arch/<arch>/* and include/arch/<arch>/ and skip
        'common'. riscv needs special handling: the riscv directory covers 2
        architectures known to twister, riscv32 and riscv64.
        """
        archs = set()

        for f in self.modified_files:
            p = re.match(r"^arch\/([^/]+)\/", f)
            if not p:
                p = re.match(r"^include\/arch\/([^/]+)\/", f)
            if p:
                if p.group(1) != 'common':
                    if p.group(1) == 'riscv':
                        archs.add('riscv32')
                        archs.add('riscv64')
                    else:
                        archs.add(p.group(1))

        _options = []
        for arch in archs:
            _options.extend(["-a", arch])

        if _options:
            logging.info('Potential architecture filters...')
            if self.platforms:
                for platform in self.platforms:
                    _options.extend(["-p", platform])

                self.get_plan(_options, True)
            else:
                self.get_plan(_options, False)

    def find_boards(self):
        """Select platforms whose board directory was touched, expanding each
        changed board directory name against the known board list."""
        boards = set()
        all_boards = set()

        for f in self.modified_files:
            # Documentation/image-only changes do not affect builds.
            if f.endswith(".rst") or f.endswith(".png") or f.endswith(".jpg"):
                continue
            p = re.match(r"^boards\/[^/]+\/([^/]+)\/", f)
            if p and p.groups():
                boards.add(p.group(1))

        # Limit search to $ZEPHYR_BASE since this is where the changed files are.
        lb_args = argparse.Namespace(**{ 'arch_roots': [repository_path], 'board_roots': [repository_path] })
        known_boards = list_boards.find_boards(lb_args)
        for b in boards:
            name_re = re.compile(b)
            for kb in known_boards:
                if name_re.search(kb.name):
                    all_boards.add(kb.name)

        _options = []
        # Threshold is on the expanded platform set, not the raw directories.
        if len(all_boards) > 20:
            logging.warning(f"{len(boards)} boards changed, this looks like a global change, skipping test handling, revert to default.")
            self.default_run = True
            return

        for board in all_boards:
            _options.extend(["-p", board])

        if _options:
            logging.info('Potential board filters...')
            self.get_plan(_options)

    def find_tests(self):
        """Select test/sample directories containing (or enclosing) the
        modified files, by walking each file's path upward until a
        testcase.yaml or sample.yaml is found."""
        tests = set()
        for f in self.modified_files:
            if f.endswith(".rst"):
                continue
            d = os.path.dirname(f)
            while d:
                if os.path.exists(os.path.join(d, "testcase.yaml")) or \
                    os.path.exists(os.path.join(d, "sample.yaml")):
                    tests.add(d)
                    break
                else:
                    d = os.path.dirname(d)

        _options = []
        for t in tests:
            _options.extend(["-T", t])

        if len(tests) > 20:
            logging.warning(f"{len(tests)} tests changed, this looks like a global change, skipping test handling, revert to default")
            self.default_run = True
            return

        if _options:
            logging.info(f'Potential test filters...({len(tests)} changed...)')
            if self.platforms:
                for platform in self.platforms:
                    _options.extend(["-p", platform])
            else:
                _options.append("--all")
            self.get_plan(_options)

    def find_tags(self):
        """Compute tag-based exclusions: any tag from tags.yaml whose file
        patterns match none of the modified files is excluded via '-e'."""
        tag_cfg_file = os.path.join(repository_path, 'scripts', 'ci', 'tags.yaml')
        with open(tag_cfg_file, 'r') as ymlfile:
            tags_config = yaml.safe_load(ymlfile)

        tags = {}
        for t, x in tags_config.items():
            tag = Tag()
            tag.exclude = True
            tag.name = t

            # tag._match_fn(path) tests if the path matches files and/or
            # files-regex
            tag._match_fn = _get_match_fn(x.get("files"), x.get("files-regex"))

            # Like tag._match_fn(path), but for files-exclude and
            # files-regex-exclude
            tag._exclude_match_fn = \
                _get_match_fn(x.get("files-exclude"), x.get("files-regex-exclude"))

            tags[tag.name] = tag

        # A tag stays excluded only if no modified file falls inside it.
        for f in self.modified_files:
            for t in tags.values():
                if t._contains(f):
                    t.exclude = False

        exclude_tags = set()
        for t in tags.values():
            if t.exclude:
                exclude_tags.add(t.name)

        for tag in exclude_tags:
            self.tag_options.extend(["-e", tag])

        if exclude_tags:
            logging.info(f'Potential tag based filters: {exclude_tags}')

    def find_excludes(self, skip=None):
        """Decide whether a full/partial twister run is still needed.

        Any modified file not covered by twister_ignore.txt forces a run
        (restricted by platforms/tag options when available).

        Args:
            skip: ignore patterns to disregard for this pass. Fixed:
                previously a mutable default argument ([]).
        """
        skip = skip or []
        with open("scripts/ci/twister_ignore.txt", "r") as twister_ignore:
            ignores = twister_ignore.read().splitlines()
            ignores = filter(lambda x: not x.startswith("#"), ignores)

        found = set()
        files = list(filter(lambda x: x, self.modified_files))

        for pattern in ignores:
            if pattern in skip:
                continue
            if pattern:
                found.update(fnmatch.filter(files, pattern))

        logging.debug(found)
        logging.debug(files)

        if sorted(files) != sorted(found):
            _options = []
            logging.info('Need to run full or partial twister...')
            self.full_twister = True
            if self.platforms:
                for platform in self.platforms:
                    _options.extend(["-p", platform])

                _options.extend(self.tag_options)
                self.get_plan(_options)
            else:
                _options.extend(self.tag_options)
                self.get_plan(_options, True)
        else:
            logging.info('No twister needed or partial twister run only...')
337
def parse_args(argv=None):
    """Parse command-line options for the test-plan generator.

    Args:
        argv: optional list of argument strings. Defaults to None, in which
            case argparse falls back to sys.argv[1:] — so existing
            parse_args() callers are unaffected. Accepting an explicit list
            makes the function unit-testable.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(
                description="Generate twister argument files based on modified file",
                allow_abbrev=False)
    parser.add_argument('-c', '--commits', default=None,
            help="Commit range in the form: a..b")
    parser.add_argument('-m', '--modified-files', default=None,
            help="File with information about changed/deleted/added files.")
    parser.add_argument('-o', '--output-file', default="testplan.json",
            help="JSON file with the test plan to be passed to twister")
    parser.add_argument('-P', '--pull-request', action="store_true",
            help="This is a pull request")
    parser.add_argument('-p', '--platform', action="append",
            help="Limit this for a platform or a list of platforms.")
    parser.add_argument('-t', '--tests_per_builder', default=700, type=int,
            help="Number of tests per builder")
    parser.add_argument('-n', '--default-matrix', default=10, type=int,
            help="Number of tests per builder")

    return parser.parse_args(argv)
358
359
if __name__ == "__main__":

    args = parse_args()
    files = []
    errors = 0

    # Determine the set of changed files: either diff a git commit range or
    # read a JSON list produced by CI.
    if args.commits:
        repo = Repo(repository_path)
        commit = repo.git.diff("--name-only", args.commits)
        files = commit.split("\n")
    elif args.modified_files:
        with open(args.modified_files, "r") as fp:
            files = json.load(fp)

    if files:
        print("Changed files:\n=========")
        print("\n".join(files))
        print("=========")

    f = Filters(files, args.pull_request, args.platform)
    f.process()

    # Drop filtered cases and de-duplicate on (name, arch, platform); count
    # suites twister already flagged as errors so we can exit non-zero.
    dup_free = []
    dup_free_set = set()
    logging.info(f'Total tests gathered: {len(f.all_tests)}')
    for ts in f.all_tests:
        if ts.get('status') == 'filtered':
            continue
        n = ts.get("name")
        a = ts.get("arch")
        p = ts.get("platform")
        if ts.get('status') == 'error':
            logging.info(f"Error found: {n} on {p} ({ts.get('reason')})")
            errors += 1
        if (n, a, p) not in dup_free_set:
            dup_free.append(ts)
            dup_free_set.add((n, a, p))

    logging.info(f'Total tests to be run: {len(dup_free)}')

    # Emit builder sizing hints consumed by the CI pipeline.
    with open(".testplan", "w") as tp:
        total_tests = len(dup_free)
        if total_tests and total_tests < args.tests_per_builder:
            nodes = 1
        else:
            nodes = round(total_tests / args.tests_per_builder)

        tp.write(f"TWISTER_TESTS={total_tests}\n")
        tp.write(f"TWISTER_NODES={nodes}\n")
        tp.write(f"TWISTER_FULL={f.full_twister}\n")
        logging.info(f'Total nodes to launch: {nodes}')

    # Fix: removed the unused 'header' column list, which was never
    # referenced anywhere in the script.

    # Write the final test plan for twister to consume.
    if dup_free:
        data = {'testsuites': dup_free}
        with open(args.output_file, 'w', newline='') as json_file:
            json.dump(data, json_file, indent=4, separators=(',', ':'))

    sys.exit(errors)
423