1"""
2Command line tool to assign tests to CI test jobs.
3"""
4import argparse
5import errno
6import json
7import os
8import re
9
10import yaml
11
12try:
13    from yaml import CLoader as Loader
14except ImportError:
15    from yaml import Loader as Loader  # type: ignore
16
17import gitlab_api
18from tiny_test_fw.Utility import CIAssignTest
19
20try:
21    from idf_py_actions.constants import PREVIEW_TARGETS, SUPPORTED_TARGETS
22except ImportError:
23    SUPPORTED_TARGETS = []
24    PREVIEW_TARGETS = []
25
# Root of the ESP-IDF checkout taken from the environment; None when IDF_PATH is not set.
IDF_PATH_FROM_ENV = os.environ.get('IDF_PATH')
27
28
class IDFCaseGroup(CIAssignTest.Group):
    """
    Base case group for IDF CI test jobs.

    Subclasses fill in ``LOCAL_BUILD_DIR`` (build output directory, joined onto IDF_PATH)
    and ``BUILD_JOB_NAMES`` (names of the GitLab build jobs whose artifacts are indexed).
    """
    LOCAL_BUILD_DIR = None
    BUILD_JOB_NAMES = None

    @classmethod
    def get_artifact_index_file(cls):
        """
        Return the path of this group's ``artifact_index.json``.

        With IDF_PATH set the file lives in ``$IDF_PATH/<LOCAL_BUILD_DIR>``; without it a
        bare file name (i.e. relative to the current working directory) is returned.
        """
        assert cls.LOCAL_BUILD_DIR
        if not IDF_PATH_FROM_ENV:
            return 'artifact_index.json'
        return os.path.join(IDF_PATH_FROM_ENV, cls.LOCAL_BUILD_DIR, 'artifact_index.json')
41
42
class IDFAssignTest(CIAssignTest.AssignTest):
    """
    Base assigner for IDF CI test groups.

    On top of the inherited case assignment, it can collect the build lists produced by
    the group's GitLab build jobs and write them into an artifact index file.
    """

    def __init__(self, test_case_path, ci_config_file, case_group=IDFCaseGroup):
        super(IDFAssignTest, self).__init__(test_case_path, ci_config_file, case_group)

    def format_build_log_path(self, parallel_num):
        """Return the artifact path of the build list written by build job ``parallel_num``."""
        return '{}/list_job_{}.json'.format(self.case_group.LOCAL_BUILD_DIR, parallel_num)

    def create_artifact_index_file(self, project_id=None, pipeline_id=None):
        """
        Download the build lists of every build job of this group and write them,
        each entry tagged with the producing job id (``ci_job_id``), to the index file.

        :param project_id: GitLab project id; defaults to the ``CI_PROJECT_ID`` env variable
        :param pipeline_id: pipeline id; defaults to the ``CI_PIPELINE_ID`` env variable
        """
        if project_id is None:
            project_id = os.getenv('CI_PROJECT_ID')
        if pipeline_id is None:
            pipeline_id = os.getenv('CI_PIPELINE_ID')
        gitlab_inst = gitlab_api.Gitlab(project_id)

        artifact_index_list = []
        for build_job_name in self.case_group.BUILD_JOB_NAMES:
            job_info_list = gitlab_inst.find_job_id(build_job_name, pipeline_id=pipeline_id)
            for job_info in job_info_list:
                # "parallel_num" could be None if "parallel" is not defined for the job
                parallel_num = job_info['parallel_num'] or 1
                raw_data = gitlab_inst.download_artifact(job_info['id'],
                                                         [self.format_build_log_path(parallel_num)])[0]
                # the build list holds one JSON object per line
                for line in raw_data.decode().splitlines():
                    if not line.strip():
                        # tolerate blank lines instead of failing in json.loads
                        continue
                    build_info = json.loads(line)
                    build_info['ci_job_id'] = job_info['id']
                    artifact_index_list.append(build_info)

        self._write_artifact_index(artifact_index_list)

    def _write_artifact_index(self, artifact_index_list):
        """Dump ``artifact_index_list`` as JSON to the group's artifact index file."""
        artifact_index_file = self.case_group.get_artifact_index_file()
        index_dir = os.path.dirname(artifact_index_file)
        # index_dir is '' when the index file is a bare file name (IDF_PATH unset);
        # os.makedirs('') raises ENOENT, so only create a directory when there is one.
        if index_dir:
            try:
                os.makedirs(index_dir)
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise

        with open(artifact_index_file, 'w') as f:
            json.dump(artifact_index_list, f)
77
78
class ExampleGroup(IDFCaseGroup):
    """Case group for example tests; cases are sorted and matched to CI jobs by env_tag and target."""
    # one shared list object serves as both sort keys and CI job match keys
    SORT_KEYS = CI_JOB_MATCH_KEYS = ['env_tag', 'target']

    LOCAL_BUILD_DIR = 'build_examples'  # type: ignore
    EXAMPLE_TARGETS = SUPPORTED_TARGETS + PREVIEW_TARGETS
    # one build job per target
    BUILD_JOB_NAMES = list(map('build_examples_cmake_{}'.format, EXAMPLE_TARGETS))  # type: ignore
85
86
class TestAppsGroup(ExampleGroup):
    """Case group for custom test apps; reuses the example group's sort/match keys."""
    LOCAL_BUILD_DIR = 'build_test_apps'
    TEST_APP_TARGETS = SUPPORTED_TARGETS + PREVIEW_TARGETS
    # one build job per target
    BUILD_JOB_NAMES = list(map('build_test_apps_{}'.format, TEST_APP_TARGETS))  # type: ignore
91
92
class ComponentUTGroup(TestAppsGroup):
    """Case group for component unit tests; reuses the example group's sort/match keys."""
    LOCAL_BUILD_DIR = 'build_component_ut'
    UNIT_TEST_TARGETS = SUPPORTED_TARGETS + PREVIEW_TARGETS
    # one build job per target
    BUILD_JOB_NAMES = list(map('build_component_ut_{}'.format, UNIT_TEST_TARGETS))  # type: ignore
97
98
class UnitTestGroup(IDFCaseGroup):
    """
    Case group for unit-test-app cases.

    Cases are plain dicts (loaded from yml). A group collects cases whose filter
    attributes agree, and emits one config entry per test function it needs.
    """
    SORT_KEYS = ['test environment', 'tags', 'chip_target']
    CI_JOB_MATCH_KEYS = ['test environment']

    LOCAL_BUILD_DIR = 'tools/unit-test-app/builds'  # type: ignore
    UNIT_TEST_TARGETS = SUPPORTED_TARGETS + PREVIEW_TARGETS
    BUILD_JOB_NAMES = ['build_esp_idf_tests_cmake_{}'.format(target) for target in UNIT_TEST_TARGETS]  # type: ignore

    MAX_CASE = 50
    # attribute names used in code -> key names used in the case dicts
    ATTR_CONVERT_TABLE = {
        'execution_time': 'execution time'
    }
    # chip target -> DUT class name inside the ttfw_idf package
    DUT_CLS_NAME = {
        'esp32': 'ESP32DUT',
        'esp32s2': 'ESP32S2DUT',
        'esp32s3': 'ESP32S3DUT',
        'esp32c3': 'ESP32C3DUT',
        'esp8266': 'ESP8266DUT',
    }

    def __init__(self, case):
        super(UnitTestGroup, self).__init__(case)
        # every tag of the seeding case also becomes a CI job match key
        tags = self._get_case_attr(case, 'tags')
        for tag_name in tags:
            self.ci_job_match_keys.add(tag_name)

    @staticmethod
    def _get_case_attr(case, attr):
        """Read ``attr`` from ``case``, translating the name via ATTR_CONVERT_TABLE first."""
        return case[UnitTestGroup.ATTR_CONVERT_TABLE.get(attr, attr)]

    def add_extra_case(self, case):
        """
        Try to add ``case`` to this group.

        Every filter attribute must be equal, except 'tags', where the case's tags only
        need to be a subset of the group's tags.

        :return: True if the case was added, False otherwise
        """
        if not self.accept_new_case():
            return False
        for filter_key in self.filters:
            case_value = self._get_case_attr(case, filter_key)
            group_value = self.filters[filter_key]
            if case_value != group_value:
                tags_are_subset = filter_key == 'tags' and set(case_value).issubset(set(group_value))
                if not tags_are_subset:
                    return False
        self.case_list.append(case)
        return True

    def _create_extra_data(self, test_cases, test_function):
        """
        Copy the attributes a unit test function needs from each case into config data.

        :param test_cases: cases that will be run by ``test_function``
        :param test_function: name of the test function running these cases
        :return: list of per-case dicts
        :raises KeyError: a multi-device/multi-stage case lacks 'child case num'
        """
        is_multi_part = test_function in ['run_multiple_devices_cases', 'run_multiple_stage_cases']
        result = []
        for test_case in test_cases:
            entry = {
                'config': self._get_case_attr(test_case, 'config'),
                'name': self._get_case_attr(test_case, 'summary'),
                'reset': self._get_case_attr(test_case, 'reset'),
                'timeout': self._get_case_attr(test_case, 'timeout'),
            }
            if is_multi_part:
                try:
                    entry['child case num'] = self._get_case_attr(test_case, 'child case num')
                except KeyError as err:
                    print('multiple devices/stages cases must contains at least two test functions')
                    print('case name: {}'.format(entry['name']))
                    raise err
            result.append(entry)
        return result

    def _divide_case_by_test_function(self):
        """
        Bucket the group's cases by the test function that has to run them.

        :return: dict mapping test function name -> list of cases (possibly empty)
        """
        buckets = {
            'run_multiple_devices_cases': [],
            'run_multiple_stage_cases': [],
            'run_unit_test_cases': [],
        }
        for test_case in self.case_list:
            # multi-device takes precedence over multi-stage
            if test_case['multi_device'] == 'Yes':
                function_name = 'run_multiple_devices_cases'
            elif test_case['multi_stage'] == 'Yes':
                function_name = 'run_multiple_stage_cases'
            else:
                function_name = 'run_unit_test_cases'
            buckets[function_name].append(test_case)
        return buckets

    def output(self):
        """
        Build the job config data for this group.

        The DUT class override is chosen from the chip target of the first case.

        :return: {"CaseConfig": list of case configs, one per test function that has cases}
        """
        target = self._get_case_attr(self.case_list[0], 'chip_target')
        if not target:
            overwrite = dict()
        else:
            overwrite = {
                'dut': {
                    'package': 'ttfw_idf',
                    'class': self.DUT_CLS_NAME[target],
                }
            }

        # no filter is needed per test function, as UT uses a few test functions for all cases
        case_configs = []
        for function_name, function_cases in self._divide_case_by_test_function().items():
            if not function_cases:
                continue
            case_configs.append({
                'name': function_name,
                'extra_data': self._create_extra_data(function_cases, function_name),
                'overwrite': overwrite,
            })
        return {'CaseConfig': case_configs}
222
223
class ExampleAssignTest(IDFAssignTest):
    """Assigner for example tests (ExampleGroup); CI test jobs are matched by ``CI_TEST_JOB_PATTERN``."""
    CI_TEST_JOB_PATTERN = re.compile(r'^example_test_.+')

    def __init__(self, test_case_path, ci_config_file):
        super(ExampleAssignTest, self).__init__(
            test_case_path=test_case_path,
            ci_config_file=ci_config_file,
            case_group=ExampleGroup,
        )
229
230
class TestAppsAssignTest(IDFAssignTest):
    """Assigner for custom test apps (TestAppsGroup); CI test jobs are matched by ``CI_TEST_JOB_PATTERN``."""
    CI_TEST_JOB_PATTERN = re.compile(r'^test_app_test_.+')

    def __init__(self, test_case_path, ci_config_file):
        super(TestAppsAssignTest, self).__init__(
            test_case_path=test_case_path,
            ci_config_file=ci_config_file,
            case_group=TestAppsGroup,
        )
236
237
class ComponentUTAssignTest(IDFAssignTest):
    """Assigner for component unit tests (ComponentUTGroup); CI test jobs are matched by ``CI_TEST_JOB_PATTERN``."""
    CI_TEST_JOB_PATTERN = re.compile(r'^component_ut_test_.+')

    def __init__(self, test_case_path, ci_config_file):
        super(ComponentUTAssignTest, self).__init__(
            test_case_path=test_case_path,
            ci_config_file=ci_config_file,
            case_group=ComponentUTGroup,
        )
243
244
class UnitTestAssignTest(IDFAssignTest):
    """
    Assigner for unit-test-app cases (UnitTestGroup).

    Cases are not discovered from Python test functions; they are loaded from the
    ``.yml`` case lists produced by the unit test build job.
    """
    CI_TEST_JOB_PATTERN = re.compile(r'^UT_.+')

    def __init__(self, test_case_path, ci_config_file):
        super(UnitTestAssignTest, self).__init__(test_case_path, ci_config_file, case_group=UnitTestGroup)

    def search_cases(self, case_filter=None):
        """
        For unit test case, we don't search for test functions.
        The unit test cases are stored in yaml files which are created in job build-idf-test.

        :param case_filter: optional mapping of lower-case case attribute name -> accepted
            values. String attribute values are lower-cased before comparison (the bot
            converts filter strings to lower case). A case lacking an attribute passes
            the filter for that attribute.
        :return: list of case dicts sorted by config, multi_stage and multi_device
        """

        def find_by_suffix(suffix, path):
            # recursively collect files under ``path`` whose name ends with ``suffix``
            res = []
            for root, _, files in os.walk(path):
                for file_name in files:
                    if file_name.endswith(suffix):
                        res.append(os.path.join(root, file_name))
            return res

        def get_test_cases_from_yml(yml_file):
            # Return the 'test cases' list of one yml file, or [] if it is unreadable or has none.
            try:
                with open(yml_file) as fr:
                    # NOTE(review): full (non-safe) Loader. Tolerable only because these files
                    # are generated by our own build job; never use it on untrusted files.
                    raw_data = yaml.load(fr, Loader=Loader)
                test_cases = raw_data['test cases']
            except (IOError, KeyError, TypeError):
                # TypeError: empty file (yaml.load returns None) or a non-mapping top-level value
                return []
            # the key may exist with a null value
            return test_cases or []

        test_cases = []
        for path in self.test_case_paths:
            if os.path.isdir(path):
                for yml_file in find_by_suffix('.yml', path):
                    test_cases.extend(get_test_cases_from_yml(yml_file))
            elif os.path.isfile(path) and path.endswith('.yml'):
                test_cases.extend(get_test_cases_from_yml(path))
            else:
                print('Test case path is invalid. Should only happen when use @bot to skip unit test.')

        if case_filter:
            # Filter keys are lower case; map them back to the original keys. Keys are
            # collected from every case (not just the first) so a key that the first case
            # lacks is still filtered on for the cases that do have it.
            key_mapping = {}
            for case in test_cases:
                for case_key in case:
                    key_mapping.setdefault(case_key.lower(), case_key)

            for key in case_filter:
                accepted_values = case_filter[key]
                filtered_cases = []
                for case in test_cases:
                    try:
                        value = case[key_mapping[key]]
                    except KeyError:
                        # case doesn't have this key, regard as filter success
                        filtered_cases.append(case)
                        continue
                    # bot converts string to lower case
                    if isinstance(value, str):
                        value = value.lower()
                    if value in accepted_values:
                        filtered_cases.append(case)
                test_cases = filtered_cases

        # sort cases with configs and test functions
        # in later stage cases with similar attributes are more likely to be assigned to the same job
        # it will reduce the count of flash DUT operations
        test_cases.sort(key=lambda x: x['config'] + x['multi_stage'] + x['multi_device'])
        return test_cases
312
313
if __name__ == '__main__':
    # case_group choice -> assigner class handling it
    assigner_classes = {
        'example_test': ExampleAssignTest,
        'custom_test': TestAppsAssignTest,
        'unit_test': UnitTestAssignTest,
        'component_ut': ComponentUTAssignTest,
    }

    parser = argparse.ArgumentParser()
    parser.add_argument('case_group', choices=['example_test', 'custom_test', 'unit_test', 'component_ut'])
    parser.add_argument('test_case_paths', nargs='+', help='test case folder or file')
    parser.add_argument('-c', '--config', help='gitlab ci config file')
    parser.add_argument('-o', '--output', help='output path of config files')
    parser.add_argument('--pipeline_id', '-p', type=int, default=None, help='pipeline_id')
    parser.add_argument('--test-case-file-pattern', help='file name pattern used to find Python test case files')
    args = parser.parse_args()

    # NOTE(review): extends the imported list in place, i.e. the same object as
    # idf_py_actions.constants.SUPPORTED_TARGETS when that import succeeded. The group
    # classes above already built their own target lists at import time, so they are
    # not affected; presumably kept for other consumers of that list.
    SUPPORTED_TARGETS.extend(PREVIEW_TARGETS)

    # Relative paths are resolved against IDF_PATH when it is set; otherwise they are
    # left relative to the current directory (os.path.join(None, ...) would raise).
    test_case_paths = [
        os.path.join(IDF_PATH_FROM_ENV, path) if IDF_PATH_FROM_ENV and not os.path.isabs(path) else path
        for path in args.test_case_paths
    ]
    # argparse 'choices' guarantees the key exists
    assigner = assigner_classes[args.case_group](test_case_paths, args.config)

    if args.test_case_file_pattern:
        assigner.CI_TEST_JOB_PATTERN = re.compile(args.test_case_file_pattern)

    assigner.assign_cases()
    assigner.output_configs(args.output)
    # honour --pipeline_id; when omitted (None) the CI_PIPELINE_ID env variable is used
    assigner.create_artifact_index_file(pipeline_id=args.pipeline_id)
345