1"""
2Command line tool to assign tests to CI test jobs.
3"""
4import argparse
5import errno
6import json
7import os
8import re
9
10import yaml
11
12try:
13    from yaml import CLoader as Loader
14except ImportError:
15    from yaml import Loader as Loader
16
17import gitlab_api
18from tiny_test_fw.Utility import CIAssignTest
19
20try:
21    from idf_py_actions.constants import PREVIEW_TARGETS, SUPPORTED_TARGETS
22except ImportError:
23    SUPPORTED_TARGETS = []
24    PREVIEW_TARGETS = []
25
# Root of the ESP-IDF checkout taken from the IDF_PATH environment variable,
# or None when the variable is not set.
IDF_PATH_FROM_ENV = os.environ.get('IDF_PATH')
27
28
class IDFCaseGroup(CIAssignTest.Group):
    """
    Base case group for ESP-IDF CI test jobs.

    Subclasses fill in ``LOCAL_BUILD_DIR`` (build output directory, relative to IDF_PATH)
    and ``BUILD_JOB_NAMES`` (names of the CI build jobs whose artifacts are indexed).
    """
    LOCAL_BUILD_DIR = None
    BUILD_JOB_NAMES = None

    @classmethod
    def get_artifact_index_file(cls):
        """
        Return the path of this group's ``artifact_index.json``.

        With IDF_PATH set, the file is placed at ``<IDF_PATH>/<LOCAL_BUILD_DIR>/artifact_index.json``;
        without it, the bare file name ``artifact_index.json`` is returned.
        """
        assert cls.LOCAL_BUILD_DIR
        if not IDF_PATH_FROM_ENV:
            return 'artifact_index.json'
        return os.path.join(IDF_PATH_FROM_ENV, cls.LOCAL_BUILD_DIR, 'artifact_index.json')
41
42
class IDFAssignTest(CIAssignTest.AssignTest):
    """
    Assign ESP-IDF test cases to CI test jobs and build the artifact index.

    The ``case_group`` class selects which build jobs are queried and where the
    artifact index file is written.
    """

    def __init__(self, test_case_path, ci_config_file, case_group=IDFCaseGroup):
        super(IDFAssignTest, self).__init__(test_case_path, ci_config_file, case_group)

    def format_build_log_path(self, parallel_num):
        """Return the artifact path of the build list written by parallel build job ``parallel_num``."""
        return '{}/list_job_{}.json'.format(self.case_group.LOCAL_BUILD_DIR, parallel_num)

    def create_artifact_index_file(self, project_id=None, pipeline_id=None):
        """
        Download the build lists of all build jobs in the pipeline and merge them into
        the case group's artifact index file.

        Every non-blank line of a build list is parsed as one JSON object; the id of the
        build job that produced it is added under ``ci_job_id``.

        :param project_id: GitLab project id, defaults to the CI_PROJECT_ID environment variable
        :param pipeline_id: GitLab pipeline id, defaults to the CI_PIPELINE_ID environment variable
        """
        if project_id is None:
            project_id = os.getenv('CI_PROJECT_ID')
        if pipeline_id is None:
            pipeline_id = os.getenv('CI_PIPELINE_ID')
        gitlab_inst = gitlab_api.Gitlab(project_id)

        artifact_index_list = []
        for build_job_name in self.case_group.BUILD_JOB_NAMES:
            job_info_list = gitlab_inst.find_job_id(build_job_name, pipeline_id=pipeline_id)
            for job_info in job_info_list:
                # could be None if "parallel_num" is not defined for the job
                parallel_num = job_info['parallel_num'] or 1
                raw_data = gitlab_inst.download_artifact(job_info['id'],
                                                         [self.format_build_log_path(parallel_num)])[0]
                # the build list holds one JSON object per line
                for line in raw_data.decode().splitlines():
                    if not line.strip():
                        continue
                    build_info = json.loads(line)
                    build_info['ci_job_id'] = job_info['id']
                    artifact_index_list.append(build_info)

        artifact_index_file = self.case_group.get_artifact_index_file()
        self._ensure_parent_dir(artifact_index_file)
        with open(artifact_index_file, 'w') as f:
            json.dump(artifact_index_list, f)

    @staticmethod
    def _ensure_parent_dir(file_path):
        """Create the parent directory of ``file_path`` if it does not exist yet."""
        dir_name = os.path.dirname(file_path)
        if not dir_name:
            # bare file name (e.g. IDF_PATH unset): the file goes to the current directory,
            # and os.makedirs('') would fail with ENOENT
            return
        try:
            os.makedirs(dir_name)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
77
78
class ExampleGroup(IDFCaseGroup):
    """Case group for example tests; artifacts come from the ``build_examples_cmake_<target>`` jobs."""
    # chained assignment: both attributes refer to the same list object
    SORT_KEYS = CI_JOB_MATCH_KEYS = ['env_tag', 'target']

    LOCAL_BUILD_DIR = 'build_examples'
    EXAMPLE_TARGETS = SUPPORTED_TARGETS + PREVIEW_TARGETS
    BUILD_JOB_NAMES = list(map('build_examples_cmake_{}'.format, EXAMPLE_TARGETS))
85
86
class TestAppsGroup(ExampleGroup):
    """Case group for custom test apps; artifacts come from the ``build_test_apps_<target>`` jobs."""
    LOCAL_BUILD_DIR = 'build_test_apps'
    TEST_APP_TARGETS = SUPPORTED_TARGETS + PREVIEW_TARGETS
    BUILD_JOB_NAMES = list(map('build_test_apps_{}'.format, TEST_APP_TARGETS))
91
92
class ComponentUTGroup(TestAppsGroup):
    """Case group for component unit tests; artifacts come from the ``build_component_ut_<target>`` jobs."""
    LOCAL_BUILD_DIR = 'build_component_ut'
    UNIT_TEST_TARGETS = SUPPORTED_TARGETS + PREVIEW_TARGETS
    BUILD_JOB_NAMES = list(map('build_component_ut_{}'.format, UNIT_TEST_TARGETS))
97
98
class UnitTestGroup(IDFCaseGroup):
    """
    Case group for unit-test-app test cases.

    Unit test cases are dicts loaded from YAML rather than Python test functions.
    ``output()`` routes each case to one of the generic runner functions
    ``run_unit_test_cases``, ``run_multiple_devices_cases`` or ``run_multiple_stage_cases``.
    """
    SORT_KEYS = ['test environment', 'tags', 'chip_target']
    CI_JOB_MATCH_KEYS = ['test environment']

    LOCAL_BUILD_DIR = 'tools/unit-test-app/builds'
    UNIT_TEST_TARGETS = SUPPORTED_TARGETS + PREVIEW_TARGETS
    BUILD_JOB_NAMES = ['build_esp_idf_tests_cmake_{}'.format(target) for target in UNIT_TEST_TARGETS]

    # presumably the maximum number of cases per group, consumed by the base Group — TODO confirm
    MAX_CASE = 50
    # attribute name used by the group logic -> key name used in the unit test case dicts
    ATTR_CONVERT_TABLE = {
        'execution_time': 'execution time'
    }
    # ttfw_idf DUT class name per chip target; targets not listed here fall back to
    # the '<TARGET>DUT' naming convention all entries follow
    DUT_CLS_NAME = {
        'esp32': 'ESP32DUT',
        'esp32s2': 'ESP32S2DUT',
        'esp32c3': 'ESP32C3DUT',
        'esp8266': 'ESP8266DUT',
    }

    def __init__(self, case):
        super(UnitTestGroup, self).__init__(case)
        # the tags of the case that creates the group become extra job-matching keys
        # (how they are matched against CI jobs is defined by the base Group)
        for tag in self._get_case_attr(case, 'tags'):
            self.ci_job_match_keys.add(tag)

    @staticmethod
    def _get_case_attr(case, attr):
        """Read ``attr`` from ``case``, translating the name through ATTR_CONVERT_TABLE first."""
        if attr in UnitTestGroup.ATTR_CONVERT_TABLE:
            attr = UnitTestGroup.ATTR_CONVERT_TABLE[attr]
        return case[attr]

    @classmethod
    def _get_dut_cls_name(cls, target):
        """
        Return the ttfw_idf DUT class name for ``target``.

        Targets missing from DUT_CLS_NAME (e.g. newly added SUPPORTED/PREVIEW targets) fall
        back to '<TARGET>DUT' instead of aborting the whole assignment with KeyError.
        """
        try:
            return cls.DUT_CLS_NAME[target]
        except KeyError:
            return '{}DUT'.format(target.upper())

    def add_extra_case(self, case):
        """
        Try to add ``case`` to this group.

        The case is accepted when the group still accepts new cases and, for every filter
        key, the case's value equals the group's value. For ``tags`` it is sufficient that
        the case's tags are a subset of the group's tags.

        :return: True if the case was added, False otherwise
        """
        added = False
        if self.accept_new_case():
            for key in self.filters:
                if self._get_case_attr(case, key) != self.filters[key]:
                    if key == 'tags':
                        if set(self._get_case_attr(case, key)).issubset(set(self.filters[key])):
                            continue
                    break
            else:
                # no filter rejected the case
                self.case_list.append(case)
                added = True
        return added

    def _create_extra_data(self, test_cases, test_function):
        """
        Build the per-case data the unit test runner function needs.

        :param test_cases: cases that will be run by ``test_function``
        :param test_function: name of the runner function
        :return: list of dicts with config/name/reset/timeout (plus 'child case num'
                 for multi-device and multi-stage runners)
        :raises KeyError: if a multi-device/multi-stage case lacks 'child case num'
        """
        case_data = []
        for case in test_cases:
            one_case_data = {
                'config': self._get_case_attr(case, 'config'),
                'name': self._get_case_attr(case, 'summary'),
                'reset': self._get_case_attr(case, 'reset'),
                'timeout': self._get_case_attr(case, 'timeout'),
            }

            if test_function in ['run_multiple_devices_cases', 'run_multiple_stage_cases']:
                try:
                    one_case_data['child case num'] = self._get_case_attr(case, 'child case num')
                except KeyError:
                    print('multiple devices/stages cases must contains at least two test functions')
                    print('case name: {}'.format(one_case_data['name']))
                    raise

            case_data.append(one_case_data)
        return case_data

    def _divide_case_by_test_function(self):
        """
        Divide the cases of this group by the runner function they need.

        Multi-device takes precedence over multi-stage; everything else is a plain unit test.

        :return: dict mapping runner function name to its list of cases
        """
        case_by_test_function = {
            'run_multiple_devices_cases': [],
            'run_multiple_stage_cases': [],
            'run_unit_test_cases': [],
        }

        for case in self.case_list:
            if case['multi_device'] == 'Yes':
                case_by_test_function['run_multiple_devices_cases'].append(case)
            elif case['multi_stage'] == 'Yes':
                case_by_test_function['run_multiple_stage_cases'].append(case)
            else:
                case_by_test_function['run_unit_test_cases'].append(case)
        return case_by_test_function

    def output(self):
        """
        Output data for the job config.

        The chip target is read from the first case (presumably shared by every case in the
        group); when set, the DUT class is overwritten with the matching ttfw_idf class.

        :return: {"CaseConfig": list of case configs, one per runner function that has cases}
        """
        target = self._get_case_attr(self.case_list[0], 'chip_target')
        if target:
            overwrite = {
                'dut': {
                    'package': 'ttfw_idf',
                    'class': self._get_dut_cls_name(target),
                }
            }
        else:
            overwrite = dict()

        case_by_test_function = self._divide_case_by_test_function()

        output_data = {
            # we don't need filter for test function, as UT uses a few test functions for all cases
            'CaseConfig': [
                {
                    'name': test_function,
                    'extra_data': self._create_extra_data(test_cases, test_function),
                    'overwrite': overwrite,
                } for test_function, test_cases in case_by_test_function.items() if test_cases
            ],
        }
        return output_data
221
222
class ExampleAssignTest(IDFAssignTest):
    """Assign example tests, grouped with ExampleGroup."""
    # presumably matched against CI job names by the base AssignTest
    CI_TEST_JOB_PATTERN = re.compile('^example_test_.+')

    def __init__(self, test_case_path, ci_config_file):
        super(ExampleAssignTest, self).__init__(test_case_path, ci_config_file, ExampleGroup)
228
229
class TestAppsAssignTest(IDFAssignTest):
    """Assign custom test-app tests, grouped with TestAppsGroup."""
    # presumably matched against CI job names by the base AssignTest
    CI_TEST_JOB_PATTERN = re.compile('^test_app_test_.+')

    def __init__(self, test_case_path, ci_config_file):
        super(TestAppsAssignTest, self).__init__(test_case_path, ci_config_file, TestAppsGroup)
235
236
class ComponentUTAssignTest(IDFAssignTest):
    """Assign component unit tests, grouped with ComponentUTGroup."""
    # presumably matched against CI job names by the base AssignTest
    CI_TEST_JOB_PATTERN = re.compile('^component_ut_test_.+')

    def __init__(self, test_case_path, ci_config_file):
        super(ComponentUTAssignTest, self).__init__(test_case_path, ci_config_file, ComponentUTGroup)
242
243
class UnitTestAssignTest(IDFAssignTest):
    """Assign unit-test-app cases (read from YAML files) to the ``UT_*`` CI jobs."""
    CI_TEST_JOB_PATTERN = re.compile(r'^UT_.+')

    def __init__(self, test_case_path, ci_config_file):
        super(UnitTestAssignTest, self).__init__(test_case_path, ci_config_file, case_group=UnitTestGroup)

    def search_cases(self, case_filter=None):
        """
        For unit test case, we don't search for test functions.
        The unit test cases is stored in a yaml file which is created in job build-idf-test.

        :param case_filter: optional dict of lower-case key -> accepted values; string
                            case values are lower-cased before comparison. Cases that lack
                            a filter key are kept.
        :return: list of case dicts, sorted by config, multi_stage and multi_device
        """

        def find_by_suffix(suffix, path):
            # all files under ``path`` (recursively) whose name ends with ``suffix``
            res = []
            for root, _, files in os.walk(path):
                for file_name in files:
                    if file_name.endswith(suffix):
                        res.append(os.path.join(root, file_name))
            return res

        def get_test_cases_from_yml(yml_file):
            # returns [] for unreadable files and for YAML files that hold no unit test cases
            try:
                with open(yml_file) as fr:
                    # NOTE(review): full (non-safe) loader; acceptable only because the files
                    # are produced by our own build jobs
                    raw_data = yaml.load(fr, Loader=Loader)
            except IOError:
                return []
            # every .yml under the search path is read: an empty file loads as None and
            # unrelated YAML may not be a mapping — neither contains unit test cases
            if not isinstance(raw_data, dict):
                return []
            return raw_data.get('test cases') or []

        test_cases = []
        for path in self.test_case_paths:
            if os.path.isdir(path):
                for yml_file in find_by_suffix('.yml', path):
                    test_cases.extend(get_test_cases_from_yml(yml_file))
            elif os.path.isfile(path) and path.endswith('.yml'):
                test_cases.extend(get_test_cases_from_yml(path))
            else:
                print('Test case path is invalid. Should only happen when use @bot to skip unit test.')

        # filter keys are lower case. Do map lower case keys with original keys.
        key_mapping = {x.lower(): x for x in test_cases[0].keys()} if test_cases else dict()
        if case_filter:
            for key in case_filter:
                filtered_cases = []
                for case in test_cases:
                    try:
                        value = case[key_mapping[key]]
                    except KeyError:
                        # case doesn't have this key, regard as filter success
                        filtered_cases.append(case)
                        continue
                    # bot converts string to lower case
                    if isinstance(value, str):
                        value = value.lower()
                    if value in case_filter[key]:
                        filtered_cases.append(case)
                test_cases = filtered_cases
        # sort cases with configs and test functions
        # in later stage cases with similar attributes are more likely to be assigned to the same job
        # it will reduce the count of flash DUT operations
        test_cases.sort(key=lambda x: x['config'] + x['multi_stage'] + x['multi_device'])
        return test_cases
311
312
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('case_group', choices=['example_test', 'custom_test', 'unit_test', 'component_ut'])
    parser.add_argument('test_case_paths', nargs='+', help='test case folder or file')
    parser.add_argument('-c', '--config', help='gitlab ci config file')
    parser.add_argument('-o', '--output', help='output path of config files')
    parser.add_argument('--pipeline_id', '-p', type=int, default=None, help='pipeline_id')
    parser.add_argument('--test-case-file-pattern', help='file name pattern used to find Python test case files')
    args = parser.parse_args()

    # NOTE: the *Group classes above already computed their target lists
    # (SUPPORTED_TARGETS + PREVIEW_TARGETS creates a new list) at import time, so this
    # only affects code that reads SUPPORTED_TARGETS from here on.
    SUPPORTED_TARGETS.extend(PREVIEW_TARGETS)

    # relative paths are resolved against IDF_PATH when it is set, otherwise used as given
    test_case_paths = [
        os.path.join(IDF_PATH_FROM_ENV, path) if IDF_PATH_FROM_ENV and not os.path.isabs(path) else path
        for path in args.test_case_paths
    ]

    assigner_classes = {
        'example_test': ExampleAssignTest,
        'custom_test': TestAppsAssignTest,
        'unit_test': UnitTestAssignTest,
        'component_ut': ComponentUTAssignTest,
    }
    if args.case_group not in assigner_classes:
        raise SystemExit(1)  # unreachable: argparse restricts case_group to the keys above
    assigner = assigner_classes[args.case_group](test_case_paths, args.config)

    if args.test_case_file_pattern:
        # NOTE(review): the option is documented as a *file name* pattern but overrides the
        # CI test job pattern — confirm against CIAssignTest which one is intended.
        assigner.CI_TEST_JOB_PATTERN = re.compile(args.test_case_file_pattern)

    assigner.assign_cases()
    assigner.output_configs(args.output)
    assigner.create_artifact_index_file()
344