1# vim: set syntax=python ts=4 :
2#
3# Copyright (c) 2018-2024 Intel Corporation
4# SPDX-License-Identifier: Apache-2.0
5
6import contextlib
7import glob
8import logging
9import mmap
10import os
11import re
12from enum import Enum
13from pathlib import Path
14
15from twisterlib.environment import canonical_zephyr_base
16from twisterlib.error import StatusAttributeError, TwisterException, TwisterRuntimeError
17from twisterlib.mixins import DisablePyTestCollectionMixin
18from twisterlib.statuses import TwisterStatus
19
# Module-wide logger, obtained under the shared 'twister' logger name.
# Its level is set to DEBUG here, at import time.
logger = logging.getLogger('twister')
logger.setLevel(logging.DEBUG)
22
class ScanPathResult:
    """Result of scanning a test source file for ztest suites and test cases.

    Attributes:
        matches                          A list of test cases, or None when
                                         nothing was found
        warnings                         A string containing one or more
                                         warnings to display
        has_registered_test_suites       Whether or not the path contained any
                                         calls to the ztest_register_test_suite
                                         macro.
        has_run_registered_test_suites   Whether or not the path contained at
                                         least one call to
                                         ztest_run_registered_test_suites.
        has_test_main                    Whether or not the path contains a
                                         definition of test_main(void)
        ztest_suite_names                Names of found ztest suites
    """
    def __init__(self,
                 matches: list[str] = None,
                 warnings: str = None,
                 has_registered_test_suites: bool = False,
                 has_run_registered_test_suites: bool = False,
                 has_test_main: bool = False,
                 ztest_suite_names: list[str] = None):
        # Avoid sharing one list object between instances.
        if ztest_suite_names is None:
            ztest_suite_names = []
        self.matches = matches
        self.warnings = warnings
        self.has_registered_test_suites = has_registered_test_suites
        self.has_run_registered_test_suites = has_run_registered_test_suites
        self.has_test_main = has_test_main
        self.ztest_suite_names = ztest_suite_names

    @staticmethod
    def _sorted_or_none(names):
        """Return a sorted copy of names, keeping None as None."""
        return None if names is None else sorted(names)

    def __eq__(self, other):
        """Compare two results, ignoring the order of the name lists.

        ``matches`` may legitimately be None (nothing found), so it is
        normalized before comparing instead of being passed straight to
        sorted(), which would raise TypeError. None and an empty list are
        still considered different.
        """
        if not isinstance(other, ScanPathResult):
            return False
        normalize = self._sorted_or_none
        return (normalize(self.matches) == normalize(other.matches) and
                self.warnings == other.warnings and
                (self.has_registered_test_suites ==
                 other.has_registered_test_suites) and
                (self.has_run_registered_test_suites ==
                 other.has_run_registered_test_suites) and
                self.has_test_main == other.has_test_main and
                (normalize(self.ztest_suite_names) ==
                 normalize(other.ztest_suite_names)))
68
def scan_file(inf_name):
    """
    Scan one source file for ztest suite declarations and test cases.

    The file is memory-mapped read-only and searched with byte regexes.
    Returns a ScanPathResult holding the test case names, any warning text,
    the suite names and the registered-suite / test_main flags.
    """
    # Legacy suite declaration. The pattern deliberately stops right after
    # the suite name: test cases written on the same line are searched for
    # starting at the end of this match.
    legacy_suite_re = re.compile(
        br"^\s*ztest_test_suite\(\s*(?P<suite_name>[a-zA-Z0-9_]+)\s*,",
        re.MULTILINE)
    registered_suite_re = re.compile(
        br"^\s*ztest_register_test_suite"
        br"\(\s*(?P<suite_name>[a-zA-Z0-9_]+)\s*,",
        re.MULTILINE)
    ztest_suite_re = re.compile(
        br"^\s*ZTEST_SUITE\(\s*(?P<suite_name>[a-zA-Z0-9_]+)\s*,",
        re.MULTILINE)
    ztest_case_re = re.compile(
        br"^\s*(?:ZTEST|ZTEST_F|ZTEST_USER|ZTEST_USER_F)\(\s*(?P<suite_name>[a-zA-Z0-9_]+)\s*,"
        br"\s*(?P<testcase_name>[a-zA-Z0-9_]+)\s*",
        re.MULTILINE)
    # Detects a custom "void test_main(void)". ztest ships a plain test_main,
    # so registering suites is fine either
    #  - without calling the run function, provided there is no custom
    #    test_main, or
    #  - together with a custom test_main, provided that one calls
    #    ztest_run_registered_test_suites.
    custom_main_re = re.compile(
        br"^\s*void\s+test_main\(void\)",
        re.MULTILINE)
    run_registered_re = re.compile(
        br"^\s*ztest_run_registered_test_suites\("
        br"(\*+|&)?(?P<state_identifier>[a-zA-Z0-9_]+)\)",
        re.MULTILINE)

    # mmap takes different keyword arguments on Windows and elsewhere.
    if os.name == 'nt':
        mmap_opts = {'access': mmap.ACCESS_READ}
    else:
        mmap_opts = {
            'flags': mmap.MAP_PRIVATE,
            'prot': mmap.PROT_READ,
            'offset': 0,
        }

    with open(inf_name) as inf, \
            contextlib.closing(mmap.mmap(inf.fileno(), 0, **mmap_opts)) as source:
        legacy_hits = list(legacy_suite_re.finditer(source))
        registered_hits = list(registered_suite_re.finditer(source))
        case_hits = list(ztest_case_re.finditer(source))
        suite_hits = list(ztest_suite_re.finditer(source))

        has_registered_test_suites = bool(registered_hits)
        has_run_registered_test_suites = \
            run_registered_re.search(source) is not None
        has_test_main = custom_main_re.search(source) is not None

        # ztest_test_suite() takes precedence over
        # ztest_register_test_suite(); both are parsed the same way.
        old_style_hits = legacy_hits or registered_hits
        if old_style_hits:
            ztest_suite_names = _extract_ztest_suite_names(old_style_hits)
            testcase_names, warnings = _find_regular_ztest_testcases(
                source,
                old_style_hits,
                has_registered_test_suites
            )
        elif suite_hits or case_hits:
            ztest_suite_names = _extract_ztest_suite_names(suite_hits)
            testcase_names, warnings = _find_new_ztest_testcases(source)
        else:
            # No suite declaration at all; possibly a client file that
            # merely includes ztest.h.
            ztest_suite_names = []
            testcase_names = warnings = None

        return ScanPathResult(
            matches=testcase_names,
            warnings=warnings,
            has_registered_test_suites=has_registered_test_suites,
            has_run_registered_test_suites=has_run_registered_test_suites,
            has_test_main=has_test_main,
            ztest_suite_names=ztest_suite_names)
169
170def _extract_ztest_suite_names(suite_regex_matches):
171    ztest_suite_names = \
172        [m.group("suite_name") for m in suite_regex_matches]
173    ztest_suite_names = \
174        [name.decode("UTF-8") for name in ztest_suite_names]
175    return ztest_suite_names
176
def _find_regular_ztest_testcases(search_area, suite_regex_matches, is_registered_test_suite):
    """
    Look up old-style test cases ("ztest_unit_test" and its variants) inside
    the suite declaration. Return the test case names and a warning string
    (None when there is nothing to warn about).
    """
    # The pattern does not look at what follows the closing parenthesis.
    unit_test_regex = re.compile(
        br"""^\s*  # empty space at the beginning is ok
        # catch the case where it is declared in the same sentence, e.g:
        #
        # ztest_test_suite(mutex_complex, ztest_user_unit_test(TESTNAME));
        # ztest_register_test_suite(n, p, ztest_user_unit_test(TESTNAME),
        (?:ztest_
            (?:test_suite\(|register_test_suite\([a-zA-Z0-9_]+\s*,\s*)
            [a-zA-Z0-9_]+\s*,\s*
        )?
        # Catch ztest[_user]_unit_test-[_setup_teardown](TESTNAME)
        ztest_(?:1cpu_)?(?:user_)?unit_test(?:_setup_teardown)?
        # Consume the argument that becomes the extra testcase
        \(\s*(?P<testcase_name>[a-zA-Z0-9_]+)
        # _setup_teardown() variant has two extra arguments that we ignore
        (?:\s*,\s*[a-zA-Z0-9_]+\s*,\s*[a-zA-Z0-9_]+)?
        \s*\)""",
        re.MULTILINE | re.VERBOSE)
    preprocessor_regex = re.compile(
        br"(#ifdef|#endif)",
        re.MULTILINE)

    begin, end = _get_search_area_boundary(
        search_area, suite_regex_matches, is_registered_test_suite)
    suite_body = search_area[begin:end]

    testcase_names, warnings = _find_ztest_testcases(suite_body, unit_test_regex)

    # Preprocessor conditionals inside the suite declaration are reported,
    # unless a warning has already been produced for this file.
    directives = preprocessor_regex.findall(suite_body)
    if not directives or warnings is not None:
        return testcase_names, warnings

    distinct = {directive.decode() for directive in directives}
    achtung = ", ".join(sorted(distinct, reverse=True))
    return testcase_names, f"found invalid {achtung} in ztest_test_suite()"
217
218
219def _get_search_area_boundary(search_area, suite_regex_matches, is_registered_test_suite):
220    """
221    Get search area boundary based on "ztest_test_suite(...)",
222    "ztest_register_test_suite(...)" or "ztest_run_test_suite(...)"
223    functions occurrence.
224    """
225    suite_run_regex = re.compile(
226        br"^\s*ztest_run_test_suite\((?P<suite_name>[a-zA-Z0-9_]+)\)",
227        re.MULTILINE)
228
229    search_start = suite_regex_matches[0].end()
230
231    suite_run_match = suite_run_regex.search(search_area)
232    if suite_run_match:
233        search_end = suite_run_match.start()
234    elif not suite_run_match and not is_registered_test_suite:
235        raise ValueError("can't find ztest_run_test_suite")
236    else:
237        search_end = re.compile(br"\);", re.MULTILINE) \
238            .search(search_area, search_start) \
239            .end()
240
241    return search_start, search_end
242
def _find_new_ztest_testcases(search_area):
    """
    Look up new-style ztest test cases declared with "ZTEST", "ZTEST_F",
    "ZTEST_USER" or "ZTEST_USER_F". Return the test case names and a warning
    string (None when there is nothing to warn about).
    """
    new_style_case_regex = re.compile(
        br"^\s*(?:ZTEST|ZTEST_F|ZTEST_USER|ZTEST_USER_F)\(\s*(?P<suite_name>[a-zA-Z0-9_]+)\s*,"
        br"\s*(?P<testcase_name>[a-zA-Z0-9_]+)\s*",
        re.MULTILINE)

    testcase_names, warnings = \
        _find_ztest_testcases(search_area, new_style_case_regex)
    return testcase_names, warnings
254
255def _find_ztest_testcases(search_area, testcase_regex):
256    """
257    Parse search area and try to find testcases defined in testcase_regex
258    argument. Return testcase names and eventually found warnings.
259    """
260    testcase_regex_matches = \
261        [m for m in testcase_regex.finditer(search_area)]
262    testcase_names = [
263        (
264            m.group("suite_name") if m.groupdict().get("suite_name") else b'',
265            m.group("testcase_name")
266        ) for m in testcase_regex_matches
267    ]
268    testcase_names = [
269        (ts_name.decode("UTF-8"), tc_name.decode("UTF-8")) for ts_name, tc_name in testcase_names
270    ]
271    warnings = None
272    for testcase_name in testcase_names:
273        if not testcase_name[1].startswith("test_"):
274            warnings = "Found a test that does not start with test_"
275    testcase_names = \
276        [(ts_name + '.' if ts_name else '') + f"{tc_name.replace('test_', '', 1)}" \
277         for (ts_name, tc_name) in testcase_names]
278
279    return testcase_names, warnings
280
def find_c_files_in(path: str, extensions: list = None) -> list:
    """
    Find C or C++ sources in the directory specified by "path"

    Files are looked up directly in "path" and in its immediate
    subdirectories (the '**' glob is used without recursive=True, so it
    matches a single directory level).

    @param path Directory to search. If it is not an existing directory,
        an empty list is returned.
    @param extensions File extensions (without the dot) to look for.
        Defaults to ['c', 'cpp', 'cxx', 'cc'].
    @return List of the found file names, each prefixed with "path".

    NOTE: the process working directory is changed for the duration of the
    search and restored afterwards, even if the search raises.
    """
    if extensions is None:
        extensions = ['c', 'cpp', 'cxx', 'cc']
    if not os.path.isdir(path):
        return []

    # back up previous CWD
    oldpwd = os.getcwd()
    os.chdir(path)
    try:
        filenames = []
        for ext in extensions:
            # glob.glob('**/*.c') does not pick up the base directory, so
            # the base directory and the subdirectories are globbed apart.
            for pattern in (f'*.{ext}', f'**/*.{ext}'):
                filenames += [os.path.join(path, x) for x in glob.glob(pattern)]
    finally:
        # restore previous CWD no matter what happened above
        os.chdir(oldpwd)

    return filenames
305
def scan_testsuite_path(testsuite_path):
    """
    Scan the sources of a test suite for ztest suites and test cases.

    First the files of the src directory (as located by _find_src_dir_path)
    are scanned, then the files found under testsuite_path itself that were
    not already covered by the src directory. Empty files are skipped.

    @param testsuite_path Path to the test suite directory
    @return Tuple (subcases, ztest_suite_names) with the collected test
        case names and ztest suite names

    Raises TwisterRuntimeError if a file in the src directory produced a
    warning, or if suites are registered and a custom test_main exists but
    ztest_run_registered_test_suites() is never called. Parse errors
    (ValueError) and warnings from files outside the src directory are only
    logged.
    """
    subcases = []
    has_registered_test_suites = False
    has_run_registered_test_suites = False
    has_test_main = False
    ztest_suite_names = []

    src_dir_path = _find_src_dir_path(testsuite_path)
    for filename in find_c_files_in(src_dir_path):
        if os.stat(filename).st_size == 0:
            continue
        try:
            result: ScanPathResult = scan_file(filename)
            if result.warnings:
                logger.error(f"{filename}: {result.warnings}")
                raise TwisterRuntimeError(f"{filename}: {result.warnings}")
            if result.matches:
                subcases += result.matches
            if result.has_registered_test_suites:
                has_registered_test_suites = True
            if result.has_run_registered_test_suites:
                has_run_registered_test_suites = True
            if result.has_test_main:
                has_test_main = True
            if result.ztest_suite_names:
                ztest_suite_names += result.ztest_suite_names

        except ValueError as e:
            logger.error(f"{filename}: error parsing source file: {e}")

    # When no src directory was found the path is "". Path("") is ".",
    # which is a parent of every relative file name, so the "already
    # scanned" check below must be disabled in that case.
    src_dir_pathlib_path = Path(src_dir_path) if src_dir_path else None
    for filename in find_c_files_in(testsuite_path):
        # If we have already scanned those files in the src_dir step, skip them.
        filename_path = Path(filename)
        if (src_dir_pathlib_path is not None and
                src_dir_pathlib_path in filename_path.parents):
            continue
        # Same as in the src_dir step: an empty file has nothing to scan.
        if os.stat(filename).st_size == 0:
            continue

        try:
            result: ScanPathResult = scan_file(filename)
            if result.warnings:
                logger.error(f"{filename}: {result.warnings}")
            if result.matches:
                subcases += result.matches
            if result.ztest_suite_names:
                ztest_suite_names += result.ztest_suite_names
        except ValueError as e:
            logger.error(f"{filename}: can't find: {e}")

    if (has_registered_test_suites and has_test_main and
            not has_run_registered_test_suites):
        warning = \
            "Found call to 'ztest_register_test_suite()' but no "\
            "call to 'ztest_run_registered_test_suites()'"
        logger.error(warning)
        raise TwisterRuntimeError(warning)

    return subcases, ztest_suite_names
363
364def _find_src_dir_path(test_dir_path):
365    """
366    Try to find src directory with test source code. Sometimes due to the
367    optimization reasons it is placed in upper directory.
368    """
369    src_dir_name = "src"
370    src_dir_path = os.path.join(test_dir_path, src_dir_name)
371    if os.path.isdir(src_dir_path):
372        return src_dir_path
373    src_dir_path = os.path.join(test_dir_path, "..", src_dir_name)
374    if os.path.isdir(src_dir_path):
375        return src_dir_path
376    return ""
377
class TestCase(DisablePyTestCollectionMixin):
    """A single test case, optionally attached to a test suite.

    Instances order by name, and str() yields the name.
    """

    def __init__(self, name=None, testsuite=None):
        # identity
        self.name = name
        self.testsuite = testsuite
        # outcome; status goes through the validating property below
        self._status = TwisterStatus.NONE
        self.reason = None
        self.duration = 0
        self.output = ""
        self.freeform = False

    @property
    def detailed_name(self) -> str:
        """Name of this case as composed by TestSuite.get_case_name_()."""
        owner = self.testsuite
        return TestSuite.get_case_name_(owner, self.name, detailed=True)

    @property
    def status(self) -> TwisterStatus:
        """Current status of the test case."""
        return self._status

    @status.setter
    def status(self, value : TwisterStatus) -> None:
        # Accept either an Enum member or a member name; anything that is
        # not a TwisterStatus member name is an illegal assignment.
        try:
            if isinstance(value, Enum):
                member_name = value.name
            else:
                member_name = value
            self._status = TwisterStatus[member_name]
        except KeyError as err:
            raise StatusAttributeError(self.__class__, value) from err

    def __lt__(self, other):
        mine, theirs = self.name, other.name
        return mine < theirs

    def __repr__(self):
        return f"<TestCase {self.name} with {self.status}>"

    def __str__(self):
        return self.name
414
class TestSuite(DisablePyTestCollectionMixin):
    """Class representing a test application
    """

    def __init__(self, suite_root, suite_path, name, data=None, detailed_test_id=True):
        """TestSuite constructor.

        This gets called by TestPlan as it finds and reads test yaml files.
        Multiple TestSuite instances may be generated from a single testcase.yaml,
        each one corresponds to an entry within that file.

        We need to have a unique name for every single test case. Since
        a testcase.yaml can define multiple tests, the canonical name for
        the test case is <workdir>/<name>.

        @param suite_root os.path.abspath() of one of the --testsuite-root
        @param suite_path path to testsuite
        @param name Name of this test case, corresponding to the entry name
            in the test case configuration file. For many test cases that just
            define one test, can be anything and is usually "test". This is
            really only used to distinguish between different cases when
            the testcase.yaml defines multiple tests
        @param data Optional mapping of attributes to set on the suite,
            see load()
        @param detailed_test_id If true, the suite name is the unique,
            path-based name from get_unique() and test case names are
            prefixed with the suite id; otherwise plain names are used

        Raises TwisterException if name is not a valid suite name, see
        check_suite_name().
        """

        workdir = os.path.relpath(suite_path, suite_root)

        # check_suite_name() raises on a bad name by itself. It must not be
        # wrapped in an assert, which is removed when Python runs with -O
        # and would silently disable the validation.
        self.check_suite_name(name, suite_root, workdir)
        self.detailed_test_id = detailed_test_id
        self.name = self.get_unique(suite_root, workdir, name) if self.detailed_test_id else name
        self.id = name

        self.source_dir = suite_path
        self.source_dir_rel = os.path.relpath(
            os.path.realpath(suite_path), start=canonical_zephyr_base
        )
        self.yamlfile = suite_path
        self.testcases = []
        self.integration_platforms = []

        self.ztest_suite_names = []

        self._status = TwisterStatus.NONE

        if data:
            self.load(data)

    @property
    def status(self) -> TwisterStatus:
        """Current status of the test suite."""
        return self._status

    @status.setter
    def status(self, value : TwisterStatus) -> None:
        # Check for illegal assignments by value
        try:
            key = value.name if isinstance(value, Enum) else value
            self._status = TwisterStatus[key]
        except KeyError as err:
            raise StatusAttributeError(self.__class__, value) from err

    def load(self, data):
        """Set every item of data, except "testcases", as an attribute.

        Raises Exception if the resulting harness is 'console' but no
        harness_config is set.
        """
        for k, v in data.items():
            if k != "testcases":
                setattr(self, k, v)

        if self.harness == 'console' and not self.harness_config:
            raise Exception(
                'Harness config error: console harness defined without a configuration.'
            )

    @staticmethod
    def get_case_name_(test_suite, tc_name, detailed=True) -> str:
        """Return tc_name, prefixed with the suite id only when a detailed
        name is requested for a suite that does not use detailed test ids.
        """
        return f"{test_suite.id}.{tc_name}" \
            if test_suite and detailed and not test_suite.detailed_test_id else f"{tc_name}"

    @staticmethod
    def compose_case_name_(test_suite, tc_name) -> str:
        """Return tc_name, prefixed with the suite id when the suite uses
        detailed test ids.
        """
        return f"{test_suite.id}.{tc_name}" \
            if test_suite and test_suite.detailed_test_id else f"{tc_name}"

    def compose_case_name(self, tc_name) -> str:
        """Compose the name of a test case of this suite."""
        return self.compose_case_name_(self, tc_name)

    def add_subcases(self, data, parsed_subcases=None, suite_names=None):
        """Populate the test cases of this suite.

        Test cases listed under data["testcases"] take precedence. Without
        them the parsed_subcases are used, and if there are none either, a
        single freeform test case named after the suite id is added.

        @param data Mapping that may hold a "testcases" list
        @param parsed_subcases Test case names found by scanning sources
        @param suite_names ztest suite names to record, if any
        """
        testcases = data.get("testcases", [])
        if testcases:
            for tc in testcases:
                self.add_testcase(name=self.compose_case_name(tc))
        elif not parsed_subcases:
            self.add_testcase(self.id, freeform=True)
        else:
            # Only add each testcase once. dict.fromkeys() removes the
            # duplicates while keeping the order of first occurrence, so
            # the result does not depend on string hashing.
            for tc in dict.fromkeys(parsed_subcases):
                self.add_testcase(name=self.compose_case_name(tc))
        if suite_names:
            self.ztest_suite_names = suite_names

    def add_testcase(self, name, freeform=False):
        """Append a new TestCase with the given name to this suite."""
        tc = TestCase(name=name, testsuite=self)
        tc.freeform = freeform
        self.testcases.append(tc)

    @staticmethod
    def get_unique(testsuite_root, workdir, name):
        """Return the unique, '/'-separated name of a test suite.

        The name is made of the testsuite root relative to ZEPHYR_BASE (only
        if the root is located below it), the workdir and the suite name.
        """

        canonical_testsuite_root = os.path.realpath(testsuite_root)
        if Path(canonical_zephyr_base) in Path(canonical_testsuite_root).parents:
            # This is in ZEPHYR_BASE, so include path in name for uniqueness
            # FIXME: We should not depend on path of test for unique names.
            relative_ts_root = os.path.relpath(canonical_testsuite_root,
                                               start=canonical_zephyr_base)
        else:
            relative_ts_root = ""

        # workdir can be "."
        unique = os.path.normpath(
            os.path.join(relative_ts_root, workdir, name)
        ).replace(os.sep, '/')
        return unique

    @staticmethod
    def check_suite_name(name, testsuite_root, workdir):
        """Validate a test suite name.

        Returns True if name has at least two dot-separated parts, raises
        TwisterException otherwise.
        """
        check = name.split(".")
        if len(check) < 2:
            raise TwisterException(f"""bad test name '{name}' in {testsuite_root}/{workdir}. \
Tests should reference the category and subsystem with a dot as a separator.
                    """
                    )
        return True
544