1#!/usr/bin/env python3
2
3import os
4import re
5import argparse
6from typing import List, Union
7
8
def get_arg():
    """Parse the command line.

    Returns the argparse namespace; its only attribute, ``version``, is the
    version string given as the single positional argument.
    """
    usage_text = (
        "Apply the specified version to affected source files. Eg.:\n"
        " python3 update_version.py 9.1.2-dev\n"
        " python3 update_version.py 9.2.0"
    )
    # RawTextHelpFormatter keeps the example lines above on separate lines in --help.
    cli = argparse.ArgumentParser(
        description=usage_text,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    cli.add_argument(
        'version',
        metavar='version',
        type=str,
        help='The version to apply',
    )
    return cli.parse_args()
19
20
class Version:
    """A parsed ``MAJOR.MINOR.PATCH[-INFO]`` version string.

    Components are kept as the digit strings given (e.g. ``major == "9"``).
    ``info`` is the suffix after the dash (without the dash), or ``""`` when
    there is no suffix; a version without a suffix counts as a release.
    """

    # Unanchored on purpose: it is also used to locate an existing version
    # embedded in a longer line of a source file.
    RE_PATTERN = r"(\d+)\.(\d+)\.(\d+)(-[\w\d]+)?"

    def __init__(self, user_input: str):
        """Parse and validate ``user_input``.

        :raises ValueError: if the whole string is not a valid version
            (ValueError subclasses Exception, so existing handlers still work).
        """
        # fullmatch anchors at both ends; a '^...$' search would also accept a
        # trailing newline because '$' matches just before one.
        match = re.fullmatch(self.RE_PATTERN, user_input)
        if match is None:
            raise ValueError(f"Invalid version format: {user_input}")

        major, minor, patch, suffix = match.groups()
        self.major = major
        self.minor = minor
        self.patch = patch
        # The suffix group includes its leading '-'; store only the text after it.
        self.info = suffix.lstrip('-') if suffix else ""

        self.is_release = not self.info
        self.as_string = user_input

    def __str__(self):
        return self.as_string

    def __repr__(self):
        return f"{type(self).__name__}({self.as_string!r})"
41
42
class RepoFileVersionReplacer:
    """Base class for rewriting the version inside one repository file.

    Subclasses override :meth:`applyVersionToLine`. :meth:`applyVersion` runs
    it over every line of the file. It writes the result back only when the
    number of rewritten lines equals ``expected_occurrences``.
    """

    # Directory holding this script; the repository root is its parent.
    DIR_SCRIPTS = os.path.dirname(__file__)
    DIR_REPO_ROOT = os.path.join(DIR_SCRIPTS, "..")

    def __init__(self, relative_path_segments: List[str], expected_occurrences: int):
        """
        :param relative_path_segments: path components of the target file,
            relative to the repository root.
        :param expected_occurrences: exact number of lines that must be
            rewritten for the update to be accepted.
        """
        self.path_relative = os.path.join(*relative_path_segments)
        self.path = os.path.join(self.DIR_REPO_ROOT, self.path_relative)
        self.expected_occurrences = expected_occurrences

    def applyVersionToLine(self, line: str, version: Version) -> Union[str, None]:
        """Return ``line`` rewritten for ``version``, or None if it holds no version.

        The base implementation matches nothing; subclasses override it.
        """
        return None

    def applyVersion(self, version: Version):
        """Apply ``version`` to the target file in place.

        :raises RuntimeError: if the number of rewritten lines differs from
            ``expected_occurrences``. The file is not written in that case.
        """
        with open(self.path, 'r', encoding='utf-8') as file:
            lines = file.readlines()

        occurrences = 0
        for i, line in enumerate(lines):
            line_with_version = self.applyVersionToLine(line, version)
            # "No match" is signalled by None (see the Union[str, None]
            # contract), not by falsiness.
            if line_with_version is not None:
                lines[i] = line_with_version
                occurrences += 1

        # Not a perfect check, but it catches obvious pitfalls: a pattern that
        # no longer matches, or one that matches more lines than intended.
        if occurrences != self.expected_occurrences:
            raise RuntimeError(
                f"Bad lines in {self.path_relative}: expected "
                f"{self.expected_occurrences} version line(s), found {occurrences}"
            )

        with open(self.path, 'w', encoding='utf-8') as file:
            file.writelines(lines)
72
73
class PrefixReplacer(RepoFileVersionReplacer):
    """Replace a version number that directly follows a fixed text prefix.

    Every ``<prefix><MAJOR.MINOR.PATCH[-INFO]>`` occurrence on a line becomes
    ``<prefix><version>``. By default exactly one matching line is expected.
    """

    def __init__(self, relative_path_segments: List[str], prefix: str, expected_occurrences=1):
        super().__init__(relative_path_segments, expected_occurrences)
        self.prefix = prefix

    def applyVersionToLine(self, line: str, version: Version):
        # The prefix is matched literally and captured as group 1, so it can
        # be re-emitted unchanged in front of the new version.
        literal_prefix = re.escape(self.prefix)
        search = f'({literal_prefix})' + Version.RE_PATTERN
        updated, hits = re.subn(search, r'\g<1>' + str(version), line)
        if hits == 0:
            return None
        return updated
85
86
class MacroReplacer(RepoFileVersionReplacer):
    """Rewrite the ``#define LVGL_VERSION_*`` macros of a C header.

    Exactly four matching lines (major, minor, patch, info) are expected.
    """

    # Macros whose value must always be emitted as a C string literal. The
    # info suffix is quoted even when it is all digits (e.g. "9.1.2-1");
    # otherwise it would be written out as a bare integer.
    STRING_KEYS = ('LVGL_VERSION_INFO',)

    def __init__(self, relative_path_segments: List[str]):
        super().__init__(relative_path_segments, 4)

    def applyVersionToLine(self, line: str, version: Version):
        """Return ``line`` with the first matching version macro updated, or None."""
        targets = {
            'LVGL_VERSION_MAJOR': version.major,
            'LVGL_VERSION_MINOR': version.minor,
            'LVGL_VERSION_PATCH': version.patch,
            'LVGL_VERSION_INFO': version.info,
        }

        for key, val in targets.items():
            pattern = self.getPattern(key)
            repl = self.getReplacement(val, as_string=key in self.STRING_KEYS)
            replaced, n = re.subn(pattern, repl, line)
            if n > 0:
                return replaced

        return None

    def getPattern(self, key: str):
        """Match ``#define <key> <value>`` at the start of a line.

        Group 1 captures everything before the value so it is kept verbatim.
        """
        return r'(^#define ' + key + r' +).+'

    def getReplacement(self, val: str, as_string: Union[bool, None] = None):
        """Build the ``re.sub`` replacement that puts ``val`` after group 1.

        :param as_string: True to emit ``val`` as a quoted string literal,
            False to emit it bare. When None, fall back to the original
            heuristic: quote unless ``val`` is numeric.
        """
        if as_string is None:
            as_string = not val.isnumeric()
        if as_string:
            val = f'"{val}"'

        return r'\g<1>' + val


class CmakeReplacer(MacroReplacer):
    """Rewrite the ``set(LVGL_VERSION_* "...")`` lines of a CMake script.

    The value sits between double quotes that already exist in the file, so
    it is inserted verbatim and never re-quoted.
    """

    def getPattern(self, key: str):
        # Group 1: up to and including the opening quote; group 2: old value;
        # group 3: the closing quote and the rest of the line.
        return r'(^set\(' + key + r' +")([^"]*)(.+)'

    def getReplacement(self, val: str, as_string: Union[bool, None] = None):
        # ``as_string`` is accepted for signature compatibility with
        # MacroReplacer but ignored: the quotes come from the matched text.
        return r'\g<1>' + val + r'\g<3>'
124
class KconfigReplacer(RepoFileVersionReplacer):
    """Replace the numeric version defaults in the Kconfig file.

    Only indented lines shaped like ``default <digits> # LVGL_VERSION_<PART>``
    are touched. Exactly three of them (major, minor, patch) are expected.
    """

    def __init__(self, relative_path_segments: List[str]):
        super().__init__(relative_path_segments, 3)

    def applyVersionToLine(self, line: str, version: Version):
        # Tried in order; the first field whose pattern matches wins.
        fields = (
            ('LVGL_VERSION_MAJOR', version.major),
            ('LVGL_VERSION_MINOR', version.minor),
            ('LVGL_VERSION_PATCH', version.patch),
        )

        for name, number in fields:
            new_line, count = re.subn(self.getPattern(name), self.getReplacement(number), line)
            if count:
                return new_line

        return None

    def getPattern(self, key: str):
        # Group 1: indentation plus "default "; group 2: the old number;
        # group 3: the trailing marker comment naming the field.
        return rf'(^\s+default\s+)(\d+) # ({key})'

    def getReplacement(self, val: str):
        # Swap in the new number and keep the marker comment after it.
        return r'\g<1>' + val + r' # \g<3>'
153
154
if __name__ == '__main__':
    version = Version(get_arg().version)
    print(f"Applying version {version} to:")

    # Files that carry the version for every build, including pre-releases.
    targets = [
        MacroReplacer(['lv_version.h']),
        CmakeReplacer(['env_support', 'cmake', 'version.cmake']),
        PrefixReplacer(['lv_conf_template.h'], 'Configuration file for v'),
        KconfigReplacer(['Kconfig']),
    ]

    # Package metadata and the Kconfig header line are bumped only for
    # releases, i.e. versions without a "-<info>" suffix.
    if version.is_release:
        targets += [
            PrefixReplacer(['library.json'], '"version": "'),
            PrefixReplacer(['library.properties'], 'version='),
            PrefixReplacer(['Kconfig'], 'Kconfig file for LVGL v'),
        ]

    for replacer in targets:
        print(f"  - {replacer.path_relative}")
        replacer.applyVersion(version)
178