1# Copyright (c) 2024 Tenstorrent AI ULC
2#
3# SPDX-License-Identifier: Apache-2.0
4
5import argparse
6import hashlib
7import os
8import re
9import shlex
10import subprocess
11import sys
12import textwrap
13import urllib.request
14from pathlib import Path
15
16import pykwalify.core
17import yaml
18from west.commands import WestCommand
19
20sys.path.append(os.fspath(Path(__file__).parent.parent))
21import zephyr_module
22from zephyr_ext_common import ZEPHYR_BASE
23
24try:
25    from yaml import CSafeDumper as SafeDumper
26    from yaml import CSafeLoader as SafeLoader
27except ImportError:
28    from yaml import SafeDumper, SafeLoader
29
# Schema describing the patches.yml format. It is looked up in the "schemas"
# directory one level above the directory containing this file and is parsed
# once, at import time, into ``patches_schema``.
WEST_PATCH_SCHEMA_PATH = Path(__file__).parents[1] / "schemas" / "patch-schema.yml"
with open(WEST_PATCH_SCHEMA_PATH) as f:
    patches_schema = yaml.load(f, Loader=SafeLoader)

# Default locations of the patch directory and of patches.yml. Both are
# relative paths; Patch.filter_args() anchors them at the manifest repository
# (or at the module given with --src-module).
WEST_PATCH_BASE = Path("zephyr") / "patches"
WEST_PATCH_YAML = Path("zephyr") / "patches.yml"
36
37
38class Patch(WestCommand):
39    def __init__(self):
40        super().__init__(
41            "patch",
42            "apply patches to the west workspace",
43            "Apply patches to the west workspace",
44            accepts_unknown_args=False,
45        )
46
47    def do_add_parser(self, parser_adder):
48        parser = parser_adder.add_parser(
49            self.name,
50            help=self.help,
51            formatter_class=argparse.RawDescriptionHelpFormatter,
52            description=self.description,
53            epilog=textwrap.dedent("""\
54            Applying Patches:
55
56                Run "west patch apply" to apply patches.
57                See "west patch apply --help" for details.
58
59            Cleaning Patches:
60
61                Run "west patch clean" to clean patches.
62                See "west patch clean --help" for details.
63
64            Listing Patches:
65
66                Run "west patch list" to list patches.
67                See "west patch list --help" for details.
68
69            Fetching Patches:
70
71                Run "west patch gh-fetch" to fetch patches from Github.
72                See "west patch gh-fetch --help" for details.
73
74            YAML File Format:
75
76            The patches.yml syntax is described in "scripts/schemas/patch-schema.yml".
77
78            patches:
79              - path: zephyr/kernel-pipe-fix-not-k-no-wait-and-ge-min-xfer-bytes.patch
80                sha256sum: e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
81                module: zephyr
82                author: Kermit D. Frog
83                email: itsnoteasy@being.gr
84                date: 2020-04-20
85                upstreamable: true
86                merge-pr: https://github.com/zephyrproject-rtos/zephyr/pull/24486
87                issue: https://github.com/zephyrproject-rtos/zephyr/issues/24485
88                merge-status: true
89                merge-commit: af926ae728c78affa89cbc1de811ab4211ed0f69
90                merge-date: 2020-04-27
91                apply-command: git apply
92                comments: |
93                  Songs about rainbows - why are there so many??
94                custom:
95                  possible-muppets-to-ask-for-clarification-with-the-above-question:
96                    - Miss Piggy
97                    - Gonzo
98                    - Fozzie Bear
99                    - Animal
100            """),
101        )
102
103        parser.add_argument(
104            "-b",
105            "--patch-base",
106            help=f"""
107                Directory containing patch files (absolute or relative to module dir,
108                default: {WEST_PATCH_BASE})""",
109            metavar="DIR",
110            type=Path,
111        )
112        parser.add_argument(
113            "-l",
114            "--patch-yml",
115            help=f"""
116                Path to patches.yml file (absolute or relative to module dir,
117                default: {WEST_PATCH_YAML})""",
118            metavar="FILE",
119            type=Path,
120        )
121        parser.add_argument(
122            "-w",
123            "--west-workspace",
124            help="West workspace",
125            metavar="DIR",
126            type=Path,
127        )
128        parser.add_argument(
129            "-sm",
130            "--src-module",
131            dest="src_module",
132            metavar="MODULE",
133            type=str,
134            help="""
135                Zephyr module containing the patch definition (name, absolute path or
136                path relative to west-workspace)""",
137        )
138        parser.add_argument(
139            "-dm",
140            "--dst-module",
141            action="append",
142            dest="dst_modules",
143            metavar="MODULE",
144            type=str,
145            help="""
146                Zephyr module to run the 'patch' command for.
147                Option can be passed multiple times.
148                If this option is not given, the 'patch' command will run for Zephyr
149                and all modules.""",
150        )
151
152        subparsers = parser.add_subparsers(
153            dest="subcommand",
154            metavar="<subcommand>",
155            help="select a subcommand. If omitted treat it as 'list'",
156        )
157
158        apply_arg_parser = subparsers.add_parser(
159            "apply",
160            help="Apply patches",
161            formatter_class=argparse.RawDescriptionHelpFormatter,
162            epilog=textwrap.dedent(
163                """
164            Applying Patches:
165
166                Run "west patch apply" to apply patches.
167            """
168            ),
169        )
170        apply_arg_parser.add_argument(
171            "-r",
172            "--roll-back",
173            help="Roll back if any patch fails to apply",
174            action="store_true",
175            default=False,
176        )
177
178        subparsers.add_parser(
179            "clean",
180            help="Clean patches",
181            formatter_class=argparse.RawDescriptionHelpFormatter,
182            epilog=textwrap.dedent(
183                """
184            Cleaning Patches:
185
186                Run "west patch clean" to clean patches.
187            """
188            ),
189        )
190
191        gh_fetch_arg_parser = subparsers.add_parser(
192            "gh-fetch",
193            help="Fetch patch from Github",
194            formatter_class=argparse.RawDescriptionHelpFormatter,
195            epilog=textwrap.dedent(
196                """
197            Fetching Patches from Github:
198
199                Run "west patch gh-fetch" to fetch a PR from Github and store it as a patch.
200                The meta data is generated and appended to the provided patches.yml file.
201
202                If no patches.yml file exists, it will be created.
203            """
204            ),
205        )
206        gh_fetch_arg_parser.add_argument(
207            "-o",
208            "--owner",
209            action="store",
210            default="zephyrproject-rtos",
211            help="Github repository owner",
212        )
213        gh_fetch_arg_parser.add_argument(
214            "-r",
215            "--repo",
216            action="store",
217            default="zephyr",
218            help="Github repository",
219        )
220        gh_fetch_arg_parser.add_argument(
221            "-pr",
222            "--pull-request",
223            metavar="ID",
224            action="store",
225            required=True,
226            type=int,
227            help="Github Pull Request ID",
228        )
229        gh_fetch_arg_parser.add_argument(
230            "-m",
231            "--module",
232            metavar="DIR",
233            action="store",
234            required=True,
235            type=Path,
236            help="Module path",
237        )
238        gh_fetch_arg_parser.add_argument(
239            "-s",
240            "--split-commits",
241            action="store_true",
242            help="Create patch files for each commit instead of a single patch for the entire PR",
243        )
244        gh_fetch_arg_parser.add_argument(
245            '-t',
246            '--token',
247            metavar='FILE',
248            dest='tokenfile',
249            help='File containing GitHub token (alternatively, use GITHUB_TOKEN env variable)',
250        )
251
252        subparsers.add_parser(
253            "list",
254            help="List patches",
255            formatter_class=argparse.RawDescriptionHelpFormatter,
256            epilog=textwrap.dedent(
257                """
258            Listing Patches:
259
260                Run "west patch list" to list patches.
261            """
262            ),
263        )
264
265        return parser
266
267    def filter_args(self, args):
268        try:
269            manifest_path = self.config.get("manifest.path")
270        except BaseException:
271            self.die("could not retrieve manifest path from west configuration")
272
273        topdir = Path(self.topdir)
274
275        if args.src_module is not None:
276            mod_path = self.get_module_path(args.src_module)
277            if mod_path is None:
278                self.die(f'Source module "{args.src_module}" not found')
279            if args.patch_base is not None and args.patch_base.is_absolute():
280                self.die("patch-base must not be an absolute path in combination with src-module")
281            if args.patch_yml is not None and args.patch_yml.is_absolute():
282                self.die("patch-yml must not be an absolute path in combination with src-module")
283            manifest_dir = topdir / mod_path
284        else:
285            manifest_dir = topdir / manifest_path
286
287        if args.patch_base is None:
288            args.patch_base = manifest_dir / WEST_PATCH_BASE
289        if not args.patch_base.is_absolute():
290            args.patch_base = manifest_dir / args.patch_base
291
292        if args.patch_yml is None:
293            args.patch_yml = manifest_dir / WEST_PATCH_YAML
294        elif not args.patch_yml.is_absolute():
295            args.patch_yml = manifest_dir / args.patch_yml
296
297        if args.west_workspace is None:
298            args.west_workspace = topdir
299        elif not args.west_workspace.is_absolute():
300            args.west_workspace = topdir / args.west_workspace
301
302        if args.dst_modules is not None:
303            args.dst_modules = [self.get_module_path(m) for m in args.dst_modules]
304
305    def load_yml(self, args, allow_missing):
306        if not os.path.isfile(args.patch_yml):
307            if not allow_missing:
308                self.inf(f"no patches to apply: {args.patch_yml} not found")
309                return None
310
311            # Return the schema defaults
312            return pykwalify.core.Core(source_data={}, schema_data=patches_schema).validate()
313
314        try:
315            with open(args.patch_yml) as f:
316                yml = yaml.load(f, Loader=SafeLoader)
317            return pykwalify.core.Core(source_data=yml, schema_data=patches_schema).validate()
318        except (yaml.YAMLError, pykwalify.errors.SchemaError) as e:
319            self.die(f"ERROR: Malformed yaml {args.patch_yml}: {e}")
320
321    def do_run(self, args, _):
322        self.filter_args(args)
323
324        west_config = Path(args.west_workspace) / ".west" / "config"
325        if not os.path.isfile(west_config):
326            self.die(f"{args.west_workspace} is not a valid west workspace")
327
328        yml = self.load_yml(args, args.subcommand in ["gh-fetch"])
329        if yml is None:
330            return
331
332        if not args.subcommand:
333            args.subcommand = "list"
334
335        method = {
336            "apply": self.apply,
337            "clean": self.clean,
338            "list": self.list,
339            "gh-fetch": self.gh_fetch,
340        }
341
342        method[args.subcommand](args, yml, args.dst_modules)
343
344    def apply(self, args, yml, dst_mods=None):
345        patches = yml.get("patches", [])
346        if not patches:
347            return
348
349        patch_count = 0
350        failed_patch = None
351        patched_mods = set()
352
353        for patch_info in patches:
354            mod = self.get_module_path(patch_info["module"])
355            if mod is None:
356                continue
357
358            if dst_mods and mod not in dst_mods:
359                continue
360
361            pth = patch_info["path"]
362            patch_path = os.path.realpath(Path(args.patch_base) / pth)
363
364            apply_cmd = patch_info["apply-command"]
365            apply_cmd_list = shlex.split(apply_cmd)
366
367            self.dbg(f"reading patch file {pth}")
368            expect_sha256 = patch_info["sha256sum"]
369            try:
370                actual_sha256 = self.get_file_sha256sum(patch_path)
371            except Exception as e:
372                self.err(f"failed to read {pth}: {e}")
373                failed_patch = pth
374                break
375
376            if actual_sha256 != expect_sha256:
377                self.dbg("FAIL")
378                self.err(
379                    f"sha256 mismatch for {pth}:\n"
380                    f"expect: {expect_sha256}\n"
381                    f"actual: {actual_sha256}"
382                )
383                failed_patch = pth
384                break
385            self.dbg("OK")
386            patch_count += 1
387
388            mod_path = Path(args.west_workspace) / mod
389            patched_mods.add(mod)
390
391            self.dbg(f"patching {mod}... ", end="")
392            apply_cmd += patch_path
393            apply_cmd_list.extend([patch_path])
394            proc = subprocess.run(
395                apply_cmd_list, capture_output=True, cwd=mod_path, encoding="utf-8"
396            )
397            if proc.returncode:
398                self.dbg("FAIL")
399                self.err(proc.stderr)
400                failed_patch = pth
401                break
402            self.dbg("OK")
403
404        if not failed_patch:
405            self.inf(f"{patch_count} patches applied successfully \\o/")
406            return
407
408        if args.roll_back:
409            self.clean(args, yml, patched_mods)
410
411        self.die(f"failed to apply patch {failed_patch}")
412
413    def clean(self, args, yml, dst_mods=None):
414        clean_cmd = yml["clean-command"]
415        checkout_cmd = yml["checkout-command"]
416
417        if not clean_cmd and not checkout_cmd:
418            self.dbg("no clean or checkout commands specified")
419            return
420
421        clean_cmd_list = shlex.split(clean_cmd)
422        checkout_cmd_list = shlex.split(checkout_cmd)
423
424        for mod in yml.get("patches", []):
425            m = self.get_module_path(mod.get("module"))
426            if m is None:
427                continue
428            if dst_mods and m not in dst_mods:
429                continue
430            mod_path = Path(args.west_workspace) / m
431
432            try:
433                if checkout_cmd:
434                    self.dbg(f"Running '{checkout_cmd}' in {mod}.. ", end="")
435                    proc = subprocess.run(
436                        checkout_cmd_list, capture_output=True, cwd=mod_path, encoding="utf-8"
437                    )
438                    if proc.returncode:
439                        self.dbg("FAIL")
440                        self.err(f"{checkout_cmd} failed for {mod}\n{proc.stderr}")
441                    else:
442                        self.dbg("OK")
443
444                if clean_cmd:
445                    self.dbg(f"Running '{clean_cmd}' in {mod}.. ", end="")
446                    proc = subprocess.run(
447                        clean_cmd_list, capture_output=True, cwd=mod_path, encoding="utf-8"
448                    )
449                    if proc.returncode:
450                        self.dbg("FAIL")
451                        self.err(f"{clean_cmd} failed for {mod}\n{proc.stderr}")
452                    else:
453                        self.dbg("OK")
454
455            except Exception as e:
456                # If this fails for some reason, just log it and continue
457                self.err(f"failed to clean up {mod}: {e}")
458
459    def list(self, args, yml, dst_mods=None):
460        patches = yml.get("patches", [])
461        if not patches:
462            return
463
464        for patch_info in patches:
465            if dst_mods and self.get_module_path(patch_info["module"]) not in dst_mods:
466                continue
467            self.inf(patch_info)
468
469    def gh_fetch(self, args, yml, mods=None):
470        if mods:
471            self.die(
472                "Module filters are not available for the gh-fetch subcommand, "
473                "pass a single -m/--module argument after the subcommand."
474            )
475
476        try:
477            from github import Auth, Github
478        except ImportError:
479            self.die("PyGithub not found; can be installed with 'pip install PyGithub'")
480
481        gh = Github(auth=Auth.Token(args.tokenfile) if args.tokenfile else None)
482        pr = gh.get_repo(f"{args.owner}/{args.repo}").get_pull(args.pull_request)
483        args.patch_base.mkdir(parents=True, exist_ok=True)
484
485        if args.split_commits:
486            for cm in pr.get_commits():
487                subject = cm.commit.message.splitlines()[0]
488                filename = "-".join(filter(None, re.split("[^a-zA-Z0-9]+", subject))) + ".patch"
489
490                # No patch URL is provided by the API, but appending .patch to the HTML works too
491                urllib.request.urlretrieve(f"{cm.html_url}.patch", args.patch_base / filename)
492
493                patch_info = {
494                    "path": filename,
495                    "sha256sum": self.get_file_sha256sum(args.patch_base / filename),
496                    "module": str(args.module),
497                    "author": cm.commit.author.name or "Hidden",
498                    "email": cm.commit.author.email or "hidden@github.com",
499                    "date": cm.commit.author.date.strftime("%Y-%m-%d"),
500                    "upstreamable": True,
501                    "merge-pr": pr.html_url,
502                    "merge-status": pr.merged,
503                }
504
505                yml.setdefault("patches", []).append(patch_info)
506        else:
507            filename = "-".join(filter(None, re.split("[^a-zA-Z0-9]+", pr.title))) + ".patch"
508            urllib.request.urlretrieve(pr.patch_url, args.patch_base / filename)
509
510            patch_info = {
511                "path": filename,
512                "sha256sum": self.get_file_sha256sum(args.patch_base / filename),
513                "module": str(args.module),
514                "author": pr.user.name or "Hidden",
515                "email": pr.user.email or "hidden@github.com",
516                "date": pr.created_at.strftime("%Y-%m-%d"),
517                "upstreamable": True,
518                "merge-pr": pr.html_url,
519                "merge-status": pr.merged,
520            }
521
522            yml.setdefault("patches", []).append(patch_info)
523
524        args.patch_yml.parent.mkdir(parents=True, exist_ok=True)
525        with open(args.patch_yml, "w") as f:
526            yaml.dump(yml, f, Dumper=SafeDumper)
527
528    @staticmethod
529    def get_file_sha256sum(filename: Path) -> str:
530        # Read as text to normalize line endings
531        with open(filename, encoding="utf-8", newline=None) as fp:
532            content = fp.read()
533
534        # NOTE: If python 3.11 is the minimum, the following can be replaced with:
535        # digest = hashlib.file_digest(BytesIO(content_bytes), "sha256")
536        digest = hashlib.new("sha256")
537        digest.update(content.encode("utf-8"))
538
539        return digest.hexdigest()
540
541    def get_module_path(self, module_name_or_path):
542        if module_name_or_path is None:
543            return None
544
545        topdir = Path(self.topdir)
546
547        if Path(module_name_or_path).is_absolute():
548            if Path(module_name_or_path).is_dir():
549                return Path(module_name_or_path).resolve().relative_to(topdir)
550            return None
551
552        if (topdir / module_name_or_path).is_dir():
553            return Path(module_name_or_path)
554
555        all_modules = zephyr_module.parse_modules(ZEPHYR_BASE, self.manifest)
556        for m in all_modules:
557            if m.meta['name'] == module_name_or_path:
558                return Path(m.project).relative_to(topdir)
559
560        return None
561