1#!/usr/bin/env python
2#
3# Copyright 2021 Espressif Systems (Shanghai) CO LTD
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import argparse
18import inspect
19import os
20import sys
21from collections import defaultdict
22from itertools import product
23
24import yaml
25
26try:
27    import pygraphviz as pgv
28except ImportError:  # used when pre-commit, skip generating image
29    pass
30
31try:
32    from typing import Union
33except ImportError:  # used for type hint
34    pass
35
# Root of the ESP-IDF tree: $IDF_PATH when it is set, otherwise the directory three
# levels above this script. The result is always an absolute path.
IDF_PATH = os.path.abspath(
    os.environ.get('IDF_PATH', os.path.join(os.path.dirname(__file__), '..', '..', '..'))
)
37
38
39def _list(str_or_list):  # type: (Union[str, list]) -> list
40    if isinstance(str_or_list, str):
41        return [str_or_list]
42    elif isinstance(str_or_list, list):
43        return str_or_list
44    else:
45        raise ValueError('Wrong type: {}. Only supports str or list.'.format(type(str_or_list)))
46
47
48def _format_nested_dict(_dict, f_tuple):  # type: (dict[str, dict], tuple[str, ...]) -> dict[str, dict]
49    res = {}
50    for k, v in _dict.items():
51        k = k.split('__')[0]
52        if isinstance(v, dict):
53            v = _format_nested_dict(v, f_tuple)
54        elif isinstance(v, list):
55            v = _format_nested_list(v, f_tuple)
56        elif isinstance(v, str):
57            v = v.format(*f_tuple)
58        res[k.format(*f_tuple)] = v
59    return res
60
61
62def _format_nested_list(_list, f_tuple):  # type: (list[str], tuple[str, ...]) -> list[str]
63    res = []
64    for item in _list:
65        if isinstance(item, list):
66            item = _format_nested_list(item, f_tuple)
67        elif isinstance(item, dict):
68            item = _format_nested_dict(item, f_tuple)
69        elif isinstance(item, str):
70            item = item.format(*f_tuple)
71        res.append(item)
72    return res
73
74
class RulesWriter:
    """
    Regenerate the auto-generated section of the GitLab CI ``rules.yml``.

    Rules come from a ``dependencies.yml`` file. Every top-level entry whose key does not
    start with ``.`` is a rule. A rule may list the ``labels`` and ``patterns`` that trigger
    it, a ``matrix`` that expands it into several rules, ``deploy`` variants, and
    ``included_in`` parent rules that inherit its labels and patterns.
    """
    AUTO_GENERATE_MARKER = inspect.cleandoc(r'''
    ##################
    # Auto Generated #
    ##################
    ''')

    # {0}: label name, {1}: the label in upper-case bot form (see bot_label_str)
    LABEL_TEMPLATE = inspect.cleandoc(r'''
    .if-label-{0}: &if-label-{0}
      if: '$BOT_LABEL_{1} || $CI_MERGE_REQUEST_LABELS =~ /^(?:[^,\n\r]+,)*{0}(?:,[^,\n\r]+)*$/i'
    ''')

    RULE_PROTECTED = '    - <<: *if-protected'
    RULE_PROTECTED_NO_LABEL = '    - <<: *if-protected-no_label'
    RULE_BUILD_ONLY = '    - <<: *if-label-build-only\n' \
                      '      when: never'
    RULE_LABEL_TEMPLATE = '    - <<: *if-label-{0}'
    RULE_PATTERN_TEMPLATE = '    - <<: *if-dev-push\n' \
                            '      changes: *patterns-{0}'
    # {0}: rule name, {1}: the rendered rule lines
    RULES_TEMPLATE = inspect.cleandoc(r"""
    .rules:{0}:
      rules:
    {1}
    """)

    # Fields that are normalized to sorted lists and inherited through ``included_in``.
    KEYWORDS = ['labels', 'patterns']

    def __init__(self, rules_yml, depend_yml):  # type: (str, str) -> None
        """
        Load both YAML files and compute the expanded rules.

        :param rules_yml: path of the ``rules.yml`` file to read and update
        :param depend_yml: path of the ``dependencies.yml`` file describing the rules
        """
        self.rules_yml = rules_yml
        # NOTE(review): FullLoader must not be used on untrusted YAML; both inputs are
        # presumably files from the repository itself. Files are closed via ``with``.
        with open(rules_yml) as fr:
            self.rules_cfg = yaml.load(fr, Loader=yaml.FullLoader)

        with open(depend_yml) as fr:
            self.full_cfg = yaml.load(fr, Loader=yaml.FullLoader)
        # Keys starting with '.' are not rules; skip them.
        self.cfg = {k: v for k, v in self.full_cfg.items() if not k.startswith('.')}
        self.cfg = self.expand_matrices()
        self.rules = self.expand_rules()

        self.graph = None

    def expand_matrices(self):  # type: () -> dict
        """
        Expand the matrix into different rules.

        Every rule is passed through :meth:`_expand_matrix`. Then, for each item of a
        rule's ``deploy`` field, a copy named ``<rule>-<item>`` is added that shares the
        rule's config dict.

        NOTE(review): the ``deploy`` copies use the un-expanded rule name and config. A rule
        that has both ``matrix`` and ``deploy`` therefore yields an unformatted copy;
        confirm that combination is not used.

        :return: the new ``{rule_name: config}`` mapping
        """
        res = {}
        for k, v in self.cfg.items():
            res.update(self._expand_matrix(k, v))

        for k, v in self.cfg.items():
            if not v:
                continue
            deploy = v.get('deploy')
            if deploy:
                for item in _list(deploy):
                    res['{}-{}'.format(k, item)] = v
        return res

    @staticmethod
    def _expand_matrix(name, cfg):  # type: (str, dict) -> dict
        """
        Expand matrix into multi keys.

        For each combination in ``itertools.product`` of the ``matrix`` lists, the rule
        name and config are formatted with that combination (see ``_format_nested_dict``).
        Note that ``matrix`` is popped from ``cfg``, so the caller's dict is modified.

        :param name: rule name, possibly containing ``str.format`` placeholders
        :param cfg: single rule dict (may be empty/None)
        :return: ``{name: cfg}`` when there is no matrix, else one entry per combination
        """
        default = {name: cfg}
        if not cfg:
            return default
        matrices = cfg.pop('matrix', None)
        if not matrices:
            return default

        res = {}
        for comb in product(*_list(matrices)):
            res.update(_format_nested_dict(default, comb))
        return res

    def expand_rules(self):  # type: () -> dict[str, dict]
        """
        Build the final per-rule config used to render ``rules.yml``.

        ``labels`` and ``patterns`` become sorted, de-duplicated lists. Every rule listed in
        an entry's ``included_in`` also receives that entry's labels and its patterns that
        are pre-defined as ``.patterns-<name>`` in ``rules.yml``. All other fields are
        copied unchanged.

        :return: ``{rule_name: {field: value}}``; missing rules/fields read as empty lists
        """
        res = defaultdict(lambda: defaultdict(set))  # type: dict[str, dict]
        for k, v in self.cfg.items():
            if not v:
                continue
            for vk, vv in v.items():
                if vk in self.KEYWORDS:
                    res[k][vk] = set(_list(vv))
                else:
                    res[k][vk] = vv
            for key in self.KEYWORDS:  # provide empty set for missing field
                if key not in res[k]:
                    res[k][key] = set()

        for k, v in self.cfg.items():
            if not v:
                continue
            if 'included_in' in v:
                for item in _list(v['included_in']):
                    if 'labels' in v:
                        res[item]['labels'].update(_list(v['labels']))
                    if 'patterns' in v:
                        for _pat in _list(v['patterns']):
                            # Patterns must be pre-defined
                            if '.patterns-{}'.format(_pat) not in self.rules_cfg:
                                print('WARNING: pattern {} not exists'.format(_pat))
                                continue
                            res[item]['patterns'].add(_pat)

        sorted_res = defaultdict(lambda: defaultdict(list))  # type: dict[str, dict]
        for k, v in res.items():
            for vk, vv in v.items():
                # Only the keyword fields hold sets. Sorting any other field would turn a
                # string (e.g. ``included_in: foo``) into a list of characters, or raise for
                # non-iterable values.
                sorted_res[k][vk] = sorted(vv) if vk in self.KEYWORDS else vv
        return sorted_res

    def new_labels_str(self):  # type: () -> str
        """
        Render one ``.if-label-*`` anchor for each distinct label used by any rule.

        :return: the anchors sorted by label name, separated by blank lines
        """
        _labels = set()
        for v in self.cfg.values():
            if not v:
                continue  # rule with an empty config
            labels = v.get('labels')
            if labels:
                _labels.update(_list(labels))
        return '\n\n'.join(self._format_label(_label) for _label in sorted(_labels))

    @classmethod
    def _format_label(cls, label):  # type: (str) -> str
        """Render ``LABEL_TEMPLATE`` for ``label``."""
        return cls.LABEL_TEMPLATE.format(label, cls.bot_label_str(label))

    @staticmethod
    def bot_label_str(label):  # type: (str) -> str
        """Upper-case ``label`` and replace ``-`` with ``_`` (as used in ``$BOT_LABEL_*``)."""
        return label.upper().replace('-', '_')

    def new_rules_str(self):  # type: () -> str
        """Render a ``.rules:<name>`` section for every rule, sorted by rule name."""
        res = []
        for k, v in sorted(self.rules.items()):
            res.append(self.RULES_TEMPLATE.format(k, self._format_rule(k, v)))
        return '\n\n'.join(res)

    def _format_rule(self, name, cfg):  # type: (str, dict) -> str
        """
        Render the ``rules:`` entries for a single rule.

        - ``*-production`` rules get only the protected-no-label entry.
        - Otherwise, the protected entry is added unless the name ends with ``-preview`` or
          starts with ``labels:``. ``test:*`` rules and ``labels:example_test-esp32c3`` get a
          build-only ``when: never`` entry. Then one entry is added per label and per
          pattern that is pre-defined in ``rules.yml``; unknown patterns only warn.

        :param name: rule name
        :param cfg: expanded rule config containing ``labels`` and ``patterns``
        :return: the entries joined by newlines
        """
        _rules = []
        if name.endswith('-production'):
            _rules.append(self.RULE_PROTECTED_NO_LABEL)
        else:
            if not (name.endswith('-preview') or name.startswith('labels:')):
                _rules.append(self.RULE_PROTECTED)
            # Special case for esp32c3 example_test, for now it only run with label
            if name.startswith('test:') or name == 'labels:example_test-esp32c3':
                _rules.append(self.RULE_BUILD_ONLY)
            for label in cfg['labels']:
                _rules.append(self.RULE_LABEL_TEMPLATE.format(label))
            for pattern in cfg['patterns']:
                if '.patterns-{}'.format(pattern) in self.rules_cfg:
                    _rules.append(self.RULE_PATTERN_TEMPLATE.format(pattern))
                else:
                    print('WARNING: pattern {} not exists'.format(pattern))
        return '\n'.join(_rules)

    def update_rules_yml(self):  # type: () -> bool
        """
        Rewrite everything after ``AUTO_GENERATE_MARKER`` in ``rules.yml`` if it is stale.

        If the marker is missing, it is appended (on a new line) before the generated
        content. Subsequent runs can then find it and stay idempotent instead of appending
        the section again.

        :return: True if the file was rewritten, False if it was already up to date
        """
        with open(self.rules_yml) as fr:
            file_str = fr.read()

        auto_generate_str = '\n{}\n\n{}\n'.format(self.new_labels_str(), self.new_rules_str())
        rest, marker, old = file_str.partition(self.AUTO_GENERATE_MARKER)
        if marker and old == auto_generate_str:
            return False

        if not marker and rest and not rest.endswith('\n'):
            # Keep the marker on its own line when it has to be appended.
            rest += '\n'
        print(self.rules_yml, 'has been modified. Please check')
        with open(self.rules_yml, 'w') as fw:
            fw.write(rest + self.AUTO_GENERATE_MARKER + auto_generate_str)
        return True
245
246
# Graphviz colors used by build_graph for label nodes, file-pattern nodes and rule nodes.
LABEL_COLOR, PATTERN_COLOR, RULE_COLOR = 'green', 'cyan', 'blue'
250
251
def build_graph(rules_dict):  # type: (dict[str, dict[str, list]]) -> pgv.AGraph
    """
    Build a left-to-right directed graph of the rule dependencies.

    Edges are drawn from a rule to each of its ``included_in`` parents (RULE_COLOR),
    from ``label:<name>`` nodes to the rule (LABEL_COLOR), and from ``pattern:<name>``
    nodes to the rule (PATTERN_COLOR). Rules with an empty config are skipped.

    :param rules_dict: ``{rule_name: config}`` mapping, e.g. ``RulesWriter.rules``
    :return: the populated graph
    """
    graph = pgv.AGraph(directed=True, rankdir='LR', concentrate=True)

    def _link_sources(prefix, names, target, color):
        # Each label/pattern gets its own node that points at the rule it triggers.
        for name in names or ():
            source = '{}:{}'.format(prefix, name)
            graph.add_node(source, color=color)
            graph.add_edge(source, target, color=color)

    for rule_name, rule_cfg in rules_dict.items():
        if not rule_cfg:
            continue
        parents = rule_cfg.get('included_in')
        for parent in (_list(parents) if parents else []):
            graph.add_node(rule_name, color=RULE_COLOR)
            graph.add_node(parent, color=RULE_COLOR)
            graph.add_edge(rule_name, parent, color=RULE_COLOR)
        _link_sources('label', rule_cfg.get('labels'), rule_name, LABEL_COLOR)
        _link_sources('pattern', rule_cfg.get('patterns'), rule_name, PATTERN_COLOR)

    return graph
276
277
def output_graph(graph, output_path='output.png'):  # type: (pgv.AGraph, str) -> None
    """
    Lay the graph out with Graphviz ``dot`` and render it to an image file.

    :param graph: graph to render
    :param output_path: a path ending in ``.png`` is written directly; any other value
        is treated as a directory that receives ``output.png``
    """
    graph.layout('dot')
    img_path = output_path if output_path.endswith('.png') else os.path.join(output_path, 'output.png')
    graph.draw(img_path)
285
286
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('rules_yml', nargs='?', default=os.path.join(IDF_PATH, '.gitlab', 'ci', 'rules.yml'),
                        help='rules.yml file path')
    parser.add_argument('dependencies_yml', nargs='?', default=os.path.join(IDF_PATH, '.gitlab', 'ci', 'dependencies',
                                                                            'dependencies.yml'),
                        help='dependencies.yml file path')
    parser.add_argument('--graph',
                        help='Specify PNG image output path. Use this argument to generate dependency graph')
    args = parser.parse_args()

    writer = RulesWriter(args.rules_yml, args.dependencies_yml)
    file_modified = writer.update_rules_yml()

    if args.graph:
        dep_tree_graph = build_graph(writer.rules)
        # Honor the user-supplied --graph path (a .png file, or a directory for output.png);
        # previously the value was ignored and the image always went to ./output.png.
        output_graph(dep_tree_graph, args.graph)

    # Exit status is 1 when rules.yml was rewritten and 0 otherwise, presumably so that a
    # pre-commit hook fails until the regenerated file is reviewed.
    sys.exit(file_modified)
306