1# Copyright (c) 2024 Tenstorrent AI ULC
2#
3# SPDX-License-Identifier: Apache-2.0
4
5import argparse
6import hashlib
7import os
8import re
9import shlex
10import subprocess
11import sys
12import textwrap
13import urllib.request
14from pathlib import Path
15
16import pykwalify.core
17import yaml
18from west.commands import WestCommand
19
20sys.path.append(os.fspath(Path(__file__).parent.parent))
21import zephyr_module
22from zephyr_ext_common import ZEPHYR_BASE
23
24try:
25    from yaml import CSafeDumper as SafeDumper
26    from yaml import CSafeLoader as SafeLoader
27except ImportError:
28    from yaml import SafeDumper, SafeLoader
29
# Schema used to validate patches.yml. It is located relative to this script:
# <directory above this file's directory>/schemas/patch-schema.yml
WEST_PATCH_SCHEMA_PATH = Path(__file__).parents[1] / "schemas" / "patch-schema.yml"
# The schema is parsed once, at import time, with the safe YAML loader.
with open(WEST_PATCH_SCHEMA_PATH) as f:
    patches_schema = yaml.load(f, Loader=SafeLoader)

# Default locations of the patch directory and of patches.yml. filter_args()
# joins them onto the manifest (or source module) directory when the
# corresponding command line option is not given.
WEST_PATCH_BASE = Path("zephyr") / "patches"
WEST_PATCH_YAML = Path("zephyr") / "patches.yml"
36
37
38class Patch(WestCommand):
39    def __init__(self):
40        super().__init__(
41            "patch",
42            "apply patches to the west workspace",
43            "Apply patches to the west workspace",
44            accepts_unknown_args=False,
45        )
46
    def do_add_parser(self, parser_adder):
        """Create the argument parser for this command and its subcommands.

        Args:
            parser_adder: Object whose ``add_parser()`` method is used to
                register the command's parser.

        Returns:
            The parser created for this command. It carries the top-level
            options plus the ``apply``, ``clean``, ``gh-fetch`` and ``list``
            subparsers.
        """
        # RawDescriptionHelpFormatter keeps the hand-formatted epilog (including
        # the YAML example) exactly as written below.
        parser = parser_adder.add_parser(
            self.name,
            help=self.help,
            formatter_class=argparse.RawDescriptionHelpFormatter,
            description=self.description,
            epilog=textwrap.dedent("""\
            Applying Patches:

                Run "west patch apply" to apply patches.
                See "west patch apply --help" for details.

            Cleaning Patches:

                Run "west patch clean" to clean patches.
                See "west patch clean --help" for details.

            Listing Patches:

                Run "west patch list" to list patches.
                See "west patch list --help" for details.

            Fetching Patches:

                Run "west patch gh-fetch" to fetch patches from Github.
                See "west patch gh-fetch --help" for details.

            YAML File Format:

            The patches.yml syntax is described in "scripts/schemas/patch-schema.yml".

            patches:
              - path: zephyr/kernel-pipe-fix-not-k-no-wait-and-ge-min-xfer-bytes.patch
                sha256sum: e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
                module: zephyr
                author: Kermit D. Frog
                email: itsnoteasy@being.gr
                date: 2020-04-20
                upstreamable: true
                merge-pr: https://github.com/zephyrproject-rtos/zephyr/pull/24486
                issue: https://github.com/zephyrproject-rtos/zephyr/issues/24485
                merge-status: true
                merge-commit: af926ae728c78affa89cbc1de811ab4211ed0f69
                merge-date: 2020-04-27
                apply-command: git apply
                comments: |
                  Songs about rainbows - why are there so many??
                custom:
                  possible-muppets-to-ask-for-clarification-with-the-above-question:
                    - Miss Piggy
                    - Gonzo
                    - Fozzie Bear
                    - Animal
            """),
        )

        # Options defined on the top-level parser (not on a subcommand).
        # None of them sets a default here; filter_args() fills in the
        # defaults after parsing.
        parser.add_argument(
            "-b",
            "--patch-base",
            help=f"""
                Directory containing patch files (absolute or relative to module dir,
                default: {WEST_PATCH_BASE})""",
            metavar="DIR",
            type=Path,
        )
        parser.add_argument(
            "-l",
            "--patch-yml",
            help=f"""
                Path to patches.yml file (absolute or relative to module dir,
                default: {WEST_PATCH_YAML})""",
            metavar="FILE",
            type=Path,
        )
        parser.add_argument(
            "-w",
            "--west-workspace",
            help="West workspace",
            metavar="DIR",
            type=Path,
        )
        parser.add_argument(
            "-sm",
            "--src-module",
            dest="src_module",
            metavar="MODULE",
            type=str,
            help="""
                Zephyr module containing the patch definition (name, absolute path or
                path relative to west-workspace)""",
        )
        # action="append" collects every occurrence into the args.dst_modules
        # list; the attribute stays None when the option is never given.
        parser.add_argument(
            "-dm",
            "--dst-module",
            action="append",
            dest="dst_modules",
            metavar="MODULE",
            type=str,
            help="""
                Zephyr module to run the 'patch' command for.
                Option can be passed multiple times.
                If this option is not given, the 'patch' command will run for Zephyr
                and all modules.""",
        )

        # The chosen subcommand name is stored in args.subcommand.
        subparsers = parser.add_subparsers(
            dest="subcommand",
            metavar="<subcommand>",
            help="select a subcommand. If omitted treat it as 'list'",
        )

        # --- "apply" subcommand ---
        apply_arg_parser = subparsers.add_parser(
            "apply",
            help="Apply patches",
            formatter_class=argparse.RawDescriptionHelpFormatter,
            epilog=textwrap.dedent(
                """
            Applying Patches:

                Run "west patch apply" to apply patches.
            """
            ),
        )
        apply_arg_parser.add_argument(
            "-r",
            "--roll-back",
            help="Roll back if any patch fails to apply",
            action="store_true",
            default=False,
        )

        # --- "clean" subcommand (no options of its own) ---
        subparsers.add_parser(
            "clean",
            help="Clean patches",
            formatter_class=argparse.RawDescriptionHelpFormatter,
            epilog=textwrap.dedent(
                """
            Cleaning Patches:

                Run "west patch clean" to clean patches.
            """
            ),
        )

        # --- "gh-fetch" subcommand ---
        gh_fetch_arg_parser = subparsers.add_parser(
            "gh-fetch",
            help="Fetch patch from Github",
            formatter_class=argparse.RawDescriptionHelpFormatter,
            epilog=textwrap.dedent(
                """
            Fetching Patches from Github:

                Run "west patch gh-fetch" to fetch a PR from Github and store it as a patch.
                The meta data is generated and appended to the provided patches.yml file.

                If no patches.yml file exists, it will be created.
            """
            ),
        )
        gh_fetch_arg_parser.add_argument(
            "-o",
            "--owner",
            action="store",
            default="zephyrproject-rtos",
            help="Github repository owner",
        )
        # NOTE: "-r" means --repo here but --roll-back for "apply"; the two
        # options are registered on different subparsers.
        gh_fetch_arg_parser.add_argument(
            "-r",
            "--repo",
            action="store",
            default="zephyr",
            help="Github repository",
        )
        gh_fetch_arg_parser.add_argument(
            "-pr",
            "--pull-request",
            metavar="ID",
            action="store",
            required=True,
            type=int,
            help="Github Pull Request ID",
        )
        gh_fetch_arg_parser.add_argument(
            "-m",
            "--module",
            metavar="DIR",
            action="store",
            required=True,
            type=Path,
            help="Module path",
        )
        gh_fetch_arg_parser.add_argument(
            "-s",
            "--split-commits",
            action="store_true",
            help="Create patch files for each commit instead of a single patch for the entire PR",
        )
        # The value is stored in args.tokenfile.
        gh_fetch_arg_parser.add_argument(
            '-t',
            '--token',
            metavar='FILE',
            dest='tokenfile',
            help='File containing GitHub token (alternatively, use GITHUB_TOKEN env variable)',
        )

        # --- "list" subcommand (no options of its own) ---
        subparsers.add_parser(
            "list",
            help="List patches",
            formatter_class=argparse.RawDescriptionHelpFormatter,
            epilog=textwrap.dedent(
                """
            Listing Patches:

                Run "west patch list" to list patches.
            """
            ),
        )

        return parser
266
267    def filter_args(self, args):
268        try:
269            manifest_path = self.config.get("manifest.path")
270        except BaseException:
271            self.die("could not retrieve manifest path from west configuration")
272
273        topdir = Path(self.topdir)
274
275        if args.src_module is not None:
276            mod_path = self.get_module_path(args.src_module)
277            if mod_path is None:
278                self.die(f'Source module "{args.src_module}" not found')
279            if args.patch_base is not None and args.patch_base.is_absolute():
280                self.die("patch-base must not be an absolute path in combination with src-module")
281            if args.patch_yml is not None and args.patch_yml.is_absolute():
282                self.die("patch-yml must not be an absolute path in combination with src-module")
283            manifest_dir = topdir / mod_path
284        else:
285            manifest_dir = topdir / manifest_path
286
287        if args.patch_base is None:
288            args.patch_base = manifest_dir / WEST_PATCH_BASE
289        if not args.patch_base.is_absolute():
290            args.patch_base = manifest_dir / args.patch_base
291
292        if args.patch_yml is None:
293            args.patch_yml = manifest_dir / WEST_PATCH_YAML
294        elif not args.patch_yml.is_absolute():
295            args.patch_yml = manifest_dir / args.patch_yml
296
297        if args.west_workspace is None:
298            args.west_workspace = topdir
299        elif not args.west_workspace.is_absolute():
300            args.west_workspace = topdir / args.west_workspace
301
302        if args.dst_modules is not None:
303            args.dst_modules = [self.get_module_path(m) for m in args.dst_modules]
304
305    def load_yml(self, args, allow_missing):
306        if not os.path.isfile(args.patch_yml):
307            if not allow_missing:
308                self.inf(f"no patches to apply: {args.patch_yml} not found")
309                return None
310
311            # Return the schema defaults
312            return pykwalify.core.Core(source_data={}, schema_data=patches_schema).validate()
313
314        try:
315            with open(args.patch_yml) as f:
316                yml = yaml.load(f, Loader=SafeLoader)
317            return pykwalify.core.Core(source_data=yml, schema_data=patches_schema).validate()
318        except (yaml.YAMLError, pykwalify.errors.SchemaError) as e:
319            self.die(f"ERROR: Malformed yaml {args.patch_yml}: {e}")
320
321    def do_run(self, args, _):
322        self.filter_args(args)
323
324        west_config = Path(args.west_workspace) / ".west" / "config"
325        if not os.path.isfile(west_config):
326            self.die(f"{args.west_workspace} is not a valid west workspace")
327
328        yml = self.load_yml(args, args.subcommand in ["gh-fetch"])
329        if yml is None:
330            return
331
332        if not args.subcommand:
333            args.subcommand = "list"
334
335        method = {
336            "apply": self.apply,
337            "clean": self.clean,
338            "list": self.list,
339            "gh-fetch": self.gh_fetch,
340        }
341
342        method[args.subcommand](args, yml, args.dst_modules)
343
344    def apply(self, args, yml, dst_mods=None):
345        patches = yml.get("patches", [])
346        if not patches:
347            return
348
349        patch_count = 0
350        failed_patch = None
351        patched_mods = set()
352
353        for patch_info in patches:
354            mod = self.get_module_path(patch_info["module"])
355            if mod is None:
356                continue
357
358            if dst_mods and mod not in dst_mods:
359                continue
360
361            pth = patch_info["path"]
362            patch_path = os.path.realpath(Path(args.patch_base) / pth)
363
364            apply_cmd = patch_info["apply-command"]
365            apply_cmd_list = shlex.split(apply_cmd)
366
367            self.dbg(f"reading patch file {pth}")
368            patch_file_data = None
369
370            try:
371                with open(patch_path, "rb") as pf:
372                    patch_file_data = pf.read()
373            except Exception as e:
374                self.err(f"failed to read {pth}: {e}")
375                failed_patch = pth
376                break
377
378            self.dbg("checking patch integrity... ", end="")
379            expect_sha256 = patch_info["sha256sum"]
380            hasher = hashlib.sha256()
381            hasher.update(patch_file_data)
382            actual_sha256 = hasher.hexdigest()
383            if actual_sha256 != expect_sha256:
384                self.dbg("FAIL")
385                self.err(
386                    f"sha256 mismatch for {pth}:\n"
387                    f"expect: {expect_sha256}\n"
388                    f"actual: {actual_sha256}"
389                )
390                failed_patch = pth
391                break
392            self.dbg("OK")
393            patch_count += 1
394            patch_file_data = None
395
396            mod_path = Path(args.west_workspace) / mod
397            patched_mods.add(mod)
398
399            self.dbg(f"patching {mod}... ", end="")
400            apply_cmd += patch_path
401            apply_cmd_list.extend([patch_path])
402            proc = subprocess.run(apply_cmd_list, cwd=mod_path)
403            if proc.returncode:
404                self.dbg("FAIL")
405                self.err(proc.stderr)
406                failed_patch = pth
407                break
408            self.dbg("OK")
409
410        if not failed_patch:
411            self.inf(f"{patch_count} patches applied successfully \\o/")
412            return
413
414        if args.roll_back:
415            self.clean(args, yml, patched_mods)
416
417        self.die(f"failed to apply patch {failed_patch}")
418
419    def clean(self, args, yml, dst_mods=None):
420        clean_cmd = yml["clean-command"]
421        checkout_cmd = yml["checkout-command"]
422
423        if not clean_cmd and not checkout_cmd:
424            self.dbg("no clean or checkout commands specified")
425            return
426
427        clean_cmd_list = shlex.split(clean_cmd)
428        checkout_cmd_list = shlex.split(checkout_cmd)
429
430        for mod in yml.get("patches", []):
431            m = self.get_module_path(mod.get("module"))
432            if m is None:
433                continue
434            if dst_mods and m not in dst_mods:
435                continue
436            mod_path = Path(args.west_workspace) / m
437
438            try:
439                if checkout_cmd:
440                    self.dbg(f"Running '{checkout_cmd}' in {mod}.. ", end="")
441                    proc = subprocess.run(checkout_cmd_list, capture_output=True, cwd=mod_path)
442                    if proc.returncode:
443                        self.dbg("FAIL")
444                        self.err(f"{checkout_cmd} failed for {mod}\n{proc.stderr}")
445                    else:
446                        self.dbg("OK")
447
448                if clean_cmd:
449                    self.dbg(f"Running '{clean_cmd}' in {mod}.. ", end="")
450                    proc = subprocess.run(clean_cmd_list, capture_output=True, cwd=mod_path)
451                    if proc.returncode:
452                        self.dbg("FAIL")
453                        self.err(f"{clean_cmd} failed for {mod}\n{proc.stderr}")
454                    else:
455                        self.dbg("OK")
456
457            except Exception as e:
458                # If this fails for some reason, just log it and continue
459                self.err(f"failed to clean up {mod}: {e}")
460
461    def list(self, args, yml, dst_mods=None):
462        patches = yml.get("patches", [])
463        if not patches:
464            return
465
466        for patch_info in patches:
467            if dst_mods and self.get_module_path(patch_info["module"]) not in dst_mods:
468                continue
469            self.inf(patch_info)
470
471    def gh_fetch(self, args, yml, mods=None):
472        if mods:
473            self.die(
474                "Module filters are not available for the gh-fetch subcommand, "
475                "pass a single -m/--module argument after the subcommand."
476            )
477
478        try:
479            from github import Auth, Github
480        except ImportError:
481            self.die("PyGithub not found; can be installed with 'pip install PyGithub'")
482
483        gh = Github(auth=Auth.Token(args.tokenfile) if args.tokenfile else None)
484        pr = gh.get_repo(f"{args.owner}/{args.repo}").get_pull(args.pull_request)
485        args.patch_base.mkdir(parents=True, exist_ok=True)
486
487        if args.split_commits:
488            for cm in pr.get_commits():
489                subject = cm.commit.message.splitlines()[0]
490                filename = "-".join(filter(None, re.split("[^a-zA-Z0-9]+", subject))) + ".patch"
491
492                # No patch URL is provided by the API, but appending .patch to the HTML works too
493                urllib.request.urlretrieve(f"{cm.html_url}.patch", args.patch_base / filename)
494
495                patch_info = {
496                    "path": filename,
497                    "sha256sum": self.get_file_sha256sum(args.patch_base / filename),
498                    "module": str(args.module),
499                    "author": cm.commit.author.name or "Hidden",
500                    "email": cm.commit.author.email or "hidden@github.com",
501                    "date": cm.commit.author.date.strftime("%Y-%m-%d"),
502                    "upstreamable": True,
503                    "merge-pr": pr.html_url,
504                    "merge-status": pr.merged,
505                }
506
507                yml.setdefault("patches", []).append(patch_info)
508        else:
509            filename = "-".join(filter(None, re.split("[^a-zA-Z0-9]+", pr.title))) + ".patch"
510            urllib.request.urlretrieve(pr.patch_url, args.patch_base / filename)
511
512            patch_info = {
513                "path": filename,
514                "sha256sum": self.get_file_sha256sum(args.patch_base / filename),
515                "module": str(args.module),
516                "author": pr.user.name or "Hidden",
517                "email": pr.user.email or "hidden@github.com",
518                "date": pr.created_at.strftime("%Y-%m-%d"),
519                "upstreamable": True,
520                "merge-pr": pr.html_url,
521                "merge-status": pr.merged,
522            }
523
524            yml.setdefault("patches", []).append(patch_info)
525
526        args.patch_yml.parent.mkdir(parents=True, exist_ok=True)
527        with open(args.patch_yml, "w") as f:
528            yaml.dump(yml, f, Dumper=SafeDumper)
529
530    @staticmethod
531    def get_file_sha256sum(filename: Path) -> str:
532        with open(filename, "rb") as fp:
533            digest = hashlib.file_digest(fp, "sha256")
534
535        return digest.hexdigest()
536
537    def get_module_path(self, module_name_or_path):
538        if module_name_or_path is None:
539            return None
540
541        topdir = Path(self.topdir)
542
543        if Path(module_name_or_path).is_absolute():
544            if Path(module_name_or_path).is_dir():
545                return Path(module_name_or_path).resolve().relative_to(topdir)
546            return None
547
548        if (topdir / module_name_or_path).is_dir():
549            return Path(module_name_or_path)
550
551        all_modules = zephyr_module.parse_modules(ZEPHYR_BASE, self.manifest)
552        for m in all_modules:
553            if m.meta['name'] == module_name_or_path:
554                return Path(m.project).relative_to(topdir)
555
556        return None
557