1# Copyright (c) 2024 Tenstorrent AI ULC
2#
3# SPDX-License-Identifier: Apache-2.0
4
5import argparse
6import hashlib
7import os
8import shlex
9import subprocess
10import textwrap
11from pathlib import Path
12
13import pykwalify.core
14import yaml
15from west.commands import WestCommand
16
17try:
18    from yaml import CSafeLoader as SafeLoader
19except ImportError:
20    from yaml import SafeLoader
21
# pykwalify schema that every patches.yml file is validated against. It is
# read once, at import time, from "schemas/" next to this module's parent
# directory.
WEST_PATCH_SCHEMA_PATH = Path(__file__).parents[1] / "schemas" / "patch-schema.yml"
patches_schema = yaml.load(WEST_PATCH_SCHEMA_PATH.read_text(), Loader=SafeLoader)

# Default locations of the patch directory and of the patch description file,
# both relative to the west manifest repository.
WEST_PATCH_BASE = Path("zephyr", "patches")
WEST_PATCH_YAML = Path("zephyr", "patches.yml")

# Placeholder path prefixes used as argparse defaults; Patch.filter_args
# rebases them onto the real manifest directory / workspace top directory.
_WEST_MANIFEST_DIR = Path("WEST_MANIFEST_DIR")
_WEST_TOPDIR = Path("WEST_TOPDIR")
31
32
class Patch(WestCommand):
    """West extension command that applies, cleans and lists module patches.

    Patches are described by a YAML file (``patches.yml`` by default) that is
    validated against ``patches_schema``. Each entry names a patch file
    (relative to ``--patch-base``), its expected SHA-256 digest, the west
    module it applies to, and the command used to apply it.
    """

    def __init__(self):
        super().__init__(
            "patch",
            "apply patches to the west workspace",
            "Apply patches to the west workspace",
            accepts_unknown_args=False,
        )

    def do_add_parser(self, parser_adder):
        """Register the ``patch`` parser and its apply/clean/list subcommands.

        Path options default to placeholders rooted at ``_WEST_MANIFEST_DIR``
        or ``_WEST_TOPDIR``; :meth:`filter_args` resolves them at run time.
        """
        parser = parser_adder.add_parser(
            self.name,
            help=self.help,
            formatter_class=argparse.RawDescriptionHelpFormatter,
            description=self.description,
            epilog=textwrap.dedent("""\
            Applying Patches:

                Run "west patch apply" to apply patches.
                See "west patch apply --help" for details.

            Cleaning Patches:

                Run "west patch clean" to clean patches.
                See "west patch clean --help" for details.

            Listing Patches:

                Run "west patch list" to list patches.
                See "west patch list --help" for details.

            YAML File Format:

            The patches.yml syntax is described in "scripts/schemas/patch-schema.yml".

            patches:
              - path: zephyr/kernel-pipe-fix-not-k-no-wait-and-ge-min-xfer-bytes.patch
                sha256sum: e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
                module: zephyr
                author: Kermit D. Frog
                email: itsnoteasy@being.gr
                date: 2020-04-20
                upstreamable: true
                merge-pr: https://github.com/zephyrproject-rtos/zephyr/pull/24486
                issue: https://github.com/zephyrproject-rtos/zephyr/issues/24485
                merge-status: true
                merge-commit: af926ae728c78affa89cbc1de811ab4211ed0f69
                merge-date: 2020-04-27
                apply-command: git apply
                comments: |
                  Songs about rainbows - why are there so many??
                custom:
                  possible-muppets-to-ask-for-clarification-with-the-above-question:
                    - Miss Piggy
                    - Gonzo
                    - Fozzie Bear
                    - Animal
            """),
        )

        parser.add_argument(
            "-b",
            "--patch-base",
            help="Directory containing patch files",
            metavar="DIR",
            default=_WEST_MANIFEST_DIR / WEST_PATCH_BASE,
            type=Path,
        )
        parser.add_argument(
            "-l",
            "--patch-yml",
            help="Path to patches.yml file",
            metavar="FILE",
            default=_WEST_MANIFEST_DIR / WEST_PATCH_YAML,
            type=Path,
        )
        parser.add_argument(
            "-w",
            "--west-workspace",
            help="West workspace",
            metavar="DIR",
            default=_WEST_TOPDIR,
            type=Path,
        )

        subparsers = parser.add_subparsers(
            dest="subcommand",
            metavar="<subcommand>",
            help="select a subcommand. If omitted treat it as 'list'",
        )

        apply_arg_parser = subparsers.add_parser(
            "apply",
            help="Apply patches",
            formatter_class=argparse.RawDescriptionHelpFormatter,
            epilog=textwrap.dedent(
                """
            Applying Patches:

                Run "west patch apply" to apply patches.
            """
            ),
        )
        apply_arg_parser.add_argument(
            "-r",
            "--roll-back",
            help="Roll back if any patch fails to apply",
            action="store_true",
            default=False,
        )

        subparsers.add_parser(
            "clean",
            help="Clean patches",
            formatter_class=argparse.RawDescriptionHelpFormatter,
            epilog=textwrap.dedent(
                """
            Cleaning Patches:

                Run "west patch clean" to clean patches.
            """
            ),
        )

        subparsers.add_parser(
            "list",
            help="List patches",
            formatter_class=argparse.RawDescriptionHelpFormatter,
            epilog=textwrap.dedent(
                """
            Listing Patches:

                Run "west patch list" to list patches.
            """
            ),
        )

        return parser

    def filter_args(self, args):
        """Resolve placeholder path arguments against the real west workspace.

        Paths starting with the ``WEST_MANIFEST_DIR`` placeholder are rebased
        onto the manifest repository; paths starting with ``WEST_TOPDIR`` are
        rebased onto the workspace top directory. Any other path is left as
        given. Dies if the manifest path cannot be determined.
        """
        try:
            manifest_path = self.config.get("manifest.path")
        except Exception:
            # Catch Exception, not BaseException, so that Ctrl-C and
            # SystemExit still propagate.
            manifest_path = None
        # NOTE(review): config.get() presumably returns None when the option is
        # unset; joining a Path with None would raise TypeError below.
        if manifest_path is None:
            self.die("could not retrieve manifest path from west configuration")

        topdir = Path(self.topdir)
        manifest_dir = topdir / manifest_path

        if args.patch_base.is_relative_to(_WEST_MANIFEST_DIR):
            args.patch_base = manifest_dir / args.patch_base.relative_to(_WEST_MANIFEST_DIR)
        if args.patch_yml.is_relative_to(_WEST_MANIFEST_DIR):
            args.patch_yml = manifest_dir / args.patch_yml.relative_to(_WEST_MANIFEST_DIR)
        if args.west_workspace.is_relative_to(_WEST_TOPDIR):
            args.west_workspace = topdir / args.west_workspace.relative_to(_WEST_TOPDIR)

    def do_run(self, args, _):
        """Load and validate the patch description file, then run the subcommand.

        Returns quietly if the patches file does not exist or is empty. Dies if
        the workspace has no ``.west/config`` or the YAML is malformed or fails
        schema validation. Defaults to ``list`` when no subcommand is given.
        """
        self.filter_args(args)

        if not os.path.isfile(args.patch_yml):
            self.inf(f"no patches to apply: {args.patch_yml} not found")
            return

        west_config = Path(args.west_workspace) / ".west" / "config"
        if not os.path.isfile(west_config):
            self.die(f"{args.west_workspace} is not a valid west workspace")

        try:
            with open(args.patch_yml) as f:
                yml = yaml.load(f, Loader=SafeLoader)
            if not yml:
                self.inf(f"{args.patch_yml} is empty")
                return
            pykwalify.core.Core(source_data=yml, schema_data=patches_schema).validate()
        except (yaml.YAMLError, pykwalify.errors.SchemaError) as e:
            self.die(f"ERROR: Malformed yaml {args.patch_yml}: {e}")

        if not args.subcommand:
            args.subcommand = "list"

        method = {
            "apply": self.apply,
            "clean": self.clean,
            "list": self.list,
        }

        method[args.subcommand](args, yml)

    def apply(self, args, yml):
        """Apply every patch listed in *yml*, in order, stopping at the first failure.

        For each entry, the patch file under ``args.patch_base`` is checked
        against its ``sha256sum``. Then ``apply-command``, with the patch path
        appended as one extra argument, is run inside the module directory.
        On failure, the modules touched so far are cleaned if
        ``args.roll_back`` is set, and the command dies.
        """
        patches = yml.get("patches", [])
        if not patches:
            return

        patch_count = 0
        failed_patch = None
        patched_mods = set()

        for patch_info in patches:
            pth = patch_info["path"]
            patch_path = os.path.realpath(Path(args.patch_base) / pth)

            # No shell is involved: the command is tokenized here and the patch
            # path is passed as its own argv element.
            apply_cmd_list = shlex.split(patch_info["apply-command"])

            if not self._check_patch_integrity(pth, patch_path, patch_info["sha256sum"]):
                failed_patch = pth
                break

            mod = patch_info["module"]
            mod_path = Path(args.west_workspace) / mod
            # Record the module before applying: a failed apply may leave
            # partial changes that a roll-back has to clean up.
            patched_mods.add(mod)

            self.dbg(f"patching {mod}... ", end="")
            if not self._run_apply_command(apply_cmd_list + [patch_path], mod_path):
                failed_patch = pth
                break
            self.dbg("OK")
            patch_count += 1

        if not failed_patch:
            self.inf(f"{patch_count} patches applied successfully \\o/")
            return

        # Only roll back modules this run actually touched; if nothing was
        # patched there is nothing to undo.
        if args.roll_back and patched_mods:
            self.clean(args, yml, patched_mods)

        self.die(f"failed to apply patch {failed_patch}")

    def _check_patch_integrity(self, pth, patch_path, expect_sha256):
        """Return True if *patch_path* is readable and its SHA-256 equals *expect_sha256*.

        Failures are reported via ``self.err``; *pth* is the name used in messages.
        """
        self.dbg(f"reading patch file {pth}")
        try:
            with open(patch_path, "rb") as pf:
                patch_file_data = pf.read()
        except OSError as e:
            self.err(f"failed to read {pth}: {e}")
            return False

        self.dbg("checking patch integrity... ", end="")
        actual_sha256 = hashlib.sha256(patch_file_data).hexdigest()
        if actual_sha256 != expect_sha256:
            self.dbg("FAIL")
            self.err(
                f"sha256 mismatch for {pth}:\n"
                f"expect: {expect_sha256}\n"
                f"actual: {actual_sha256}"
            )
            return False
        self.dbg("OK")
        return True

    def _run_apply_command(self, cmd, cwd):
        """Run the argv list *cmd* in directory *cwd*; return True on a zero exit status.

        The command's output is not captured, so the user sees it directly.
        Failures are reported via ``self.err``.
        """
        try:
            # cwd= leaves this process's working directory untouched, even
            # when the command fails.
            proc = subprocess.run(cmd, cwd=cwd)
        except OSError as e:
            # e.g. the module directory or the apply tool does not exist
            self.dbg("FAIL")
            self.err(f"could not run '{shlex.join(cmd)}' in {cwd}: {e}")
            return False
        if proc.returncode:
            self.dbg("FAIL")
            self.err(f"'{shlex.join(cmd)}' exited with status {proc.returncode}")
            return False
        return True

    def clean(self, args, yml, mods=None):
        """Restore modules by running the configured checkout and clean commands.

        ``checkout-command`` runs before ``clean-command`` in each module; either
        may be omitted. Failures are logged, and processing continues with the
        next module.

        :param mods: module names to restrict cleaning to. ``None`` (the default)
            means every module referenced in *yml*; an empty collection cleans
            nothing.
        """
        clean_cmd = yml.get("clean-command")
        checkout_cmd = yml.get("checkout-command")

        if not clean_cmd and not checkout_cmd:
            self.dbg("no clean or checkout commands specified")
            return

        # Only tokenize the commands that are actually set: shlex.split(None)
        # would read from stdin on older Pythons.
        commands = [(cmd, shlex.split(cmd)) for cmd in (checkout_cmd, clean_cmd) if cmd]

        for mod, mod_path in Patch.get_mod_paths(args, yml).items():
            if mods is not None and mod not in mods:
                continue
            try:
                for cmd, cmd_list in commands:
                    self.dbg(f"Running '{cmd}' in {mod}.. ", end="")
                    proc = subprocess.run(cmd_list, cwd=mod_path, capture_output=True)
                    if proc.returncode:
                        self.dbg("FAIL")
                        stderr = proc.stderr.decode(errors="replace")
                        self.err(f"{cmd} failed for {mod}\n{stderr}")
                    else:
                        self.dbg("OK")
            except Exception as e:
                # Best effort: if this fails for some reason, log it and
                # continue with the remaining modules.
                self.err(f"failed to clean up {mod}: {e}")

    def list(self, args, yml):
        """Print each patch entry from *yml*; print nothing if there are none."""
        patches = yml.get("patches", [])
        if not patches:
            return

        for patch_info in patches:
            self.inf(patch_info)

    @staticmethod
    def get_mod_paths(args, yml):
        """Map each module named in *yml*'s patches to its resolved path in the workspace."""
        patches = yml.get("patches", [])
        if not patches:
            return {}

        mod_paths = {}
        for patch_info in patches:
            mod = patch_info["module"]
            mod_paths[mod] = os.path.realpath(Path(args.west_workspace) / mod)

        return mod_paths
355