1"""
2Utility to autogenerate Zephyr DT pinctrl files for all STM32 microcontrollers.
3
4Usage::
5
    python3 genpinctrl.py -p /path/to/stm32-open-pin-data-repository
                          [-o /path/to/output_dir] [-v]
                          [-f FAMILY [-f FAMILY ...] | -nf FAMILY [-nf FAMILY ...]]
8
9Copyright (c) 2020 Teslabs Engineering S.L.
10
11SPDX-License-Identifier: Apache-2.0
12"""
13
14import argparse
15from collections import OrderedDict
16import logging
17from pathlib import Path
18import re
19import shutil
20from subprocess import check_output, STDOUT, CalledProcessError
21import xml.etree.ElementTree as ET
22
23from jinja2 import Environment, FileSystemLoader
24import yaml
25
26
#: Module logger; its level and handler are set up by the command line entry
#: point at the bottom of this file.
logger = logging.getLogger(__name__)


#: Script directory.
SCRIPT_DIR = Path(__file__).absolute().parent

#: Repository root (used for defaults).
REPO_ROOT = SCRIPT_DIR.joinpath("..", "..")

#: Configuration file.
CONFIG_FILE = SCRIPT_DIR / "stm32-pinctrl-config.yaml"

#: Configuration file for F1 series.
CONFIG_F1_FILE = SCRIPT_DIR / "stm32f1-pinctrl-config.yaml"

#: pinctrl template file (looked up in the script directory).
PINCTRL_TEMPLATE = "pinctrl-template.j2"

#: Readme template file (looked up in the script directory).
README_TEMPLATE = "readme-template.j2"

#: MCU XML namespace, in ElementTree "{uri}" form. This is only a placeholder:
#: detect_xml_namespace() overwrites it when the namespace can be detected.
NS = "{http://dummy.com}"

#: pinctrl peripheral addresses for each family, keyed by the lowercase family
#: name as read from the MCU XML files (e.g. "STM32F1" -> "stm32f1").
PINCTRL_ADDRESSES: dict[str, int] = {
    "stm32c0": 0x50000000,
    "stm32f0": 0x48000000,
    "stm32f1": 0x40010800,
    "stm32f2": 0x40020000,
    "stm32f3": 0x48000000,
    "stm32f4": 0x40020000,
    "stm32f7": 0x40020000,
    "stm32g0": 0x50000000,
    "stm32g4": 0x48000000,
    "stm32h5": 0x42020000,
    "stm32h7": 0x58020000,
    "stm32h7rs": 0x58020000,
    "stm32l0": 0x50000000,
    "stm32l1": 0x40020000,
    "stm32l4": 0x48000000,
    "stm32l5": 0x42020000,
    "stm32mp1": 0x50002000,
    "stm32n6": 0x56020000,
    "stm32u0": 0x50000000,
    "stm32u5": 0x42020000,
    "stm32wba": 0x42020000,
    "stm32wb": 0x48000000,
    "stm32wb0": 0x48000000,
    "stm32wl": 0x48000000,
}

#: Allowed pin name modifiers. A pin name suffix found here is kept
#: (lowercased) in the pin "mod" field; any other suffix is dropped.
PIN_MODS: list[str] = [
    "_C",  # Pins with analog switch (H7)
]
83
84
class FamilyFilter:
    """Select which STM32 families are processed.

    The filter is either an allow-list (only the listed families are
    processed) or a deny-list (the listed families are skipped). While no
    family is configured, the filter is inactive and lets every model through.

    Matching is done by string prefix. For example, "STM32WB" also matches
    "STM32WBA..." and "STM32WB0..." models.

    Attributes:
        filtered_families_list: Normalized family prefixes, e.g.
            ``["STM32F1", "STM32L4"]``.
        filter_list: The list last configured through ``set_filters`` (it
            used to always be None because ``_prepare`` returned nothing).
        excluding: True for a deny-list, False for an allow-list.
    """

    def __init__(self):
        self.filtered_families_list = []
        self.filter_list = self.filtered_families_list
        # initialised so the object is always in a consistent state
        self.excluding = False

    def _prepare(self, filters, excluding: bool):
        """Normalize ``filters`` and install them as the active filter.

        Args:
            filters: Family names, with or without the "STM32" prefix and in
                any case (e.g. "f1", "STM32L4"). Surrounding whitespace is
                ignored.
            excluding: True if the listed families must be skipped, False if
                only the listed families must be processed.

        Returns:
            The normalized family list, e.g. ``["STM32F1", "STM32L4"]``.
        """
        self.excluding = excluding
        self.filtered_families_list = [
            "STM32" + family.strip().upper().removeprefix("STM32")
            for family in filters
        ]
        return self.filtered_families_list

    def set_filters(self, allow_filter, forbid_filter):
        """Configure the filter (typically from command line arguments).

        Args:
            allow_filter: Families to process exclusively, or None.
            forbid_filter: Families to skip, or None. Ignored when
                ``allow_filter`` is given.
        """
        if allow_filter is not None:
            self.filter_list = self._prepare(allow_filter, False)
        elif forbid_filter is not None:
            self.filter_list = self._prepare(forbid_filter, True)

    def is_active(self) -> bool:
        """Is the filter active?"""
        return len(self.filtered_families_list) > 0

    def should_skip_model(self, model: str) -> bool:
        """Should processing of STM32 model be skipped?

        Args:
            model: STM32 model string (any string that starts with 'STM32yyy'
                where yyy is family code, e.g. an MCU file name or a GPIO IP
                name).

        Returns:
            True if the model must be skipped.
        """
        if not self.is_active():
            return False

        matched = model.startswith(tuple(self.filtered_families_list))
        # deny-list: skip matched models; allow-list: skip unmatched models
        return matched == self.excluding
130
131
#: STM32 family selection filter. It is shared by get_gpio_ip_afs(),
#: get_mcu_signals() and main(), and configured from the command line.
FAMILY_FILTER = FamilyFilter()
134
135
def validate_config_entry(entry, family):
    """Check that a pin configuration entry is well formed.

    Checks run in this order: name, match, mode, bias, drive, slew-rate. The
    first failing check raises.

    Args:
        entry: Pin configuration entry (dictionary loaded from YAML).
        family: STM32 family, e.g. "STM32F1". F1 entries must declare a mode,
            may use the "input" mode, use the F1 slew rate names, and may only
            set a non-"disable" bias in input mode.

    Raises:
        ValueError: If entry is not valid.
    """

    is_f1 = family == "STM32F1"

    name = entry.get("name")
    if not name:
        raise ValueError("Missing entry name")

    if not entry.get("match"):
        raise ValueError(f"Missing entry match for {name}")

    mode = entry.get("mode")
    if is_f1 and not mode:
        raise ValueError(f"Missing entry mode for {name}")
    valid_modes = ("analog", "input", "alternate") if is_f1 else ("analog", "alternate")
    if mode and mode not in valid_modes:
        raise ValueError(f"Invalid mode for {name}: {mode}")

    bias = entry.get("bias")
    if bias:
        if bias not in ("disable", "pull-up", "pull-down"):
            raise ValueError(f"Invalid bias for {name}: {bias}")
        if is_f1 and mode != "input" and bias != "disable":
            raise ValueError(
                f"Bias can only be set for input mode on F1 (entry: {name})"
            )

    drive = entry.get("drive")
    if drive and drive not in ("push-pull", "open-drain"):
        raise ValueError(f"Invalid drive for {name}: {drive}")

    slew_rate = entry.get("slew-rate")
    if slew_rate:
        if is_f1:
            valid_rates = ("max-speed-10mhz", "max-speed-2mhz", "max-speed-50mhz")
        else:
            valid_rates = ("low-speed", "medium-speed", "high-speed", "very-high-speed")
        if slew_rate not in valid_rates:
            raise ValueError(f"Invalid slew rate for {name}: {slew_rate}")
200
201
def format_mode(mode, af):
    """Format mode for DT (non-F1 series).

    Args:
        mode: Operation mode (analog, alternate).
        af: Alternate function number, or None when there is none (only valid
            for analog mode).

    Returns:
        DT AF definition ("ANALOG" or "AF<n>").

    Raises:
        ValueError: If the mode is unsupported, or the mode is alternate but no
            AF number is available (e.g. a config entry forces "alternate" on
            a signal without alternate function).
    """

    if mode == "analog":
        return "ANALOG"

    if mode == "alternate":
        # without this check f"{None:d}" raises an unhelpful TypeError
        if af is None:
            raise ValueError("Alternate mode requires an AF number")
        return f"AF{af:d}"

    raise ValueError(f"Unsupported mode: {mode}")
219
220
def format_mode_f1(mode):
    """Translate an F1 pin mode into its DT mode definition.

    Args:
        mode: Mode (analog, input, alternate).

    Returns:
        "ANALOG", "GPIO_IN" or "ALTERNATE" respectively.

    Raises:
        ValueError: For any other mode.
    """

    dt_modes = (
        ("analog", "ANALOG"),
        ("input", "GPIO_IN"),
        ("alternate", "ALTERNATE"),
    )
    for config_mode, dt_mode in dt_modes:
        if mode == config_mode:
            return dt_mode

    raise ValueError(f"Unsupported mode: {mode}")
239
240
def format_remap(remap):
    """Return the DT remap value for a remap definition.

    Args:
        remap: Remap block name (e.g. "USART1_REMAP1"), or 0/None when the
            signal has no remap.

    Returns:
        "NO_REMAP" for 0/None, otherwise ``remap`` unchanged.
    """

    has_no_remap = remap is None or remap == 0
    return "NO_REMAP" if has_no_remap else remap
255
256
def format_remap_name(remap):
    """Build the DT node name suffix for a remap value.

    Args:
        remap: Remap block name (e.g. "USART1_REMAP1"), or 0/None when the
            signal has no remap.

    Returns:
        "_remap1", "_remap2" or "_remap3" when the name contains REMAP1,
        REMAP2 or REMAP3 respectively. Returns an empty string for no remap,
        for REMAP0 and for anything else.
    """

    if remap is None or remap == 0:
        return ""

    # First match wins, scanning REMAP0 upwards; REMAP0 gets no suffix.
    for index in range(4):
        if f"REMAP{index}" in remap:
            return f"_remap{index}" if index else ""

    return ""
279
280
def _get_f1_signal_remaps(signal, signal_name, gpio_ip):
    """Collect the remap block names of an STM32F1 GPIO IP signal.

    Args:
        signal: ``PinSignal`` XML element.
        signal_name: Signal name (used for logging).
        gpio_ip: GPIO IP name (used for logging).

    Returns:
        List of valid remap block names (e.g. ``["TIM2_REMAP1"]``). The list
        is empty if the signal has none; problems are logged.
    """

    remap_blocks = signal.findall(NS + "RemapBlock")
    # findall() returns a (possibly empty) list, never None
    if not remap_blocks:
        logger.error(f"Missing remaps for {signal_name} (ip: {gpio_ip})")
        return []

    remaps = []
    for remap_block in remap_blocks:
        name = remap_block.get("Name")
        if not name or not re.search(r"^[A-Z0-9]+_REMAP(\d+)", name):
            logger.error(f"Unexpected remap format: {name} (ip: {gpio_ip})")
            continue
        remaps.append(name)

    return remaps


def _get_signal_af(signal, signal_name, gpio_ip):
    """Extract the AF number of a (non-F1) GPIO IP signal.

    Args:
        signal: ``PinSignal`` XML element.
        signal_name: Signal name (used for logging).
        gpio_ip: GPIO IP name (used for logging).

    Returns:
        AF number parsed from a ``GPIO_AF<n>_...`` value, or None if it cannot
        be determined (an error is logged).
    """

    param = signal.find(NS + "SpecificParameter")
    if param is None:
        logger.error(f"Missing parameters for {signal_name} (ip: {gpio_ip})")
        return None

    value = param.find(NS + "PossibleValue")
    if value is None:
        logger.error(f"Missing signal value for {signal_name} (ip: {gpio_ip})")
        return None

    # an empty <PossibleValue/> has text None, which re.search() rejects
    m = re.search(r"^GPIO_AF(\d+)_[A-Z0-9]+", value.text or "")
    if not m:
        logger.error(f"Unexpected AF format: {value.text} (ip: {gpio_ip})")
        return None

    return int(m.group(1))


def get_gpio_ip_afs(data_path):
    """Obtain all GPIO IP alternate functions.

    Example output (illustrative)::

        {
            "STM32L4P_gpio_v1_0": {
                "PA2": {
                    "EVENTOUT": 15,
                    "LPUART1_TX": 8,
                    ...
                },
                ...
            },
            "STM32F103x4_gpio_v1_0": {
                "PB3": {
                    "SPI1_SCK": ["SPI1_REMAP1"],
                    ...
                },
                ...
            },
            ...
        }

    Notes:
        For the F1 series (GPIO IP name containing "STM32F1"), each signal
        maps to the list of its remap block names (``<IP>_REMAP<n>``) rather
        than to an AF number. Signals whose AF or remap information cannot be
        parsed are logged and left out. GPIO IPs rejected by the family filter
        are skipped.

    Args:
        data_path: STM32 Open Pin Data repository path.

    Returns:
        Dictionary of alternate functions.

    Raises:
        FileNotFoundError: If the IP database folder does not exist.
    """

    ip_path = data_path / "mcu" / "IP"
    if not ip_path.exists():
        raise FileNotFoundError(f"IP DB folder '{ip_path}' does not exist")

    results = dict()

    for gpio_file in ip_path.glob("GPIO-*_Modes.xml"):
        m = re.search(r"GPIO-(.*)_Modes.xml", gpio_file.name)
        # glob() may be case-insensitive on some platforms
        if not m:
            continue
        gpio_ip = m.group(1)

        if FAMILY_FILTER.should_skip_model(gpio_ip):
            continue

        is_f1 = "STM32F1" in gpio_ip

        gpio_ip_entries = dict()
        results[gpio_ip] = gpio_ip_entries

        gpio_root = ET.parse(gpio_file).getroot()

        for pin in gpio_root.findall(NS + "GPIO_Pin"):
            pin_entries = dict()
            gpio_ip_entries[pin.get("Name")] = pin_entries

            for signal in pin.findall(NS + "PinSignal"):
                signal_name = signal.get("Name")

                if is_f1:
                    remaps = _get_f1_signal_remaps(signal, signal_name, gpio_ip)
                    if remaps:
                        # extend: a signal may be listed more than once
                        pin_entries.setdefault(signal_name, list()).extend(remaps)
                else:
                    af_n = _get_signal_af(signal, signal_name, gpio_ip)
                    if af_n is not None:
                        pin_entries[signal_name] = af_n

    return results
392
393
def _get_pin_signals(pin, pin_afs, family):
    """Collect the signals of one MCU I/O pin.

    Args:
        pin: ``Pin`` XML element of the MCU file.
        pin_afs: Alternate functions of this pin, i.e. one pin entry of
            ``get_gpio_ip_afs()`` output.
        family: STM32 family, e.g. "STM32F1".

    Returns:
        List of ``{"name": ..., "af": ...}`` entries. A signal with several
        alternate functions (F1 remaps) yields one entry per AF.
    """

    pin_signals = list()

    for signal in pin.findall(NS + "Signal"):
        signal_name = signal.get("Name")

        # plain GPIO entry: only kept (as ANALOG) if the pin supports analog
        if signal_name == "GPIO":
            io_modes = signal.get("IOModes")
            if io_modes and "Analog" in io_modes:
                pin_signals.append({"name": "ANALOG", "af": None})
            continue

        if signal_name is None:
            continue

        if signal_name in pin_afs:
            found_afs = pin_afs[signal_name]
            if not isinstance(found_afs, list):
                found_afs = [found_afs]
        # STM32F1: assume NO_REMAP (af=0) if signal is not listed in pin_afs
        elif family == "STM32F1":
            found_afs = [0]
        # Non STM32F1: No alternate function found, mode is analog
        else:
            found_afs = [None]

        pin_signals.extend({"name": signal_name, "af": af} for af in found_afs)

    return pin_signals


def get_mcu_signals(data_path, gpio_ip_afs):
    """Obtain all MCU signals.

    Example output::

        {
            "STM32WB": [
                {
                    "name": "STM32WB30CEUx",
                    "pins": [
                        {
                            "port": "a",
                            "pin": 0,
                            "mod": "",
                            "signals": [
                                {
                                    "name": "ADC1_IN5",
                                    "af": None,
                                },
                                {
                                    "name": "UART1_TX",
                                    "af": 3,
                                },
                                ...
                            ],
                        },
                        ...
                    ],
                },
                ...
            ]
        }

    Notes:
        Family names have any "+" removed. For the F1 series, "af" is a remap
        block name, or 0 (no remap) for signals without remap information.
        MCU files rejected by the family filter, or whose GPIO IP is unknown,
        are skipped.

    Args:
        data_path: STM32 Open Pin Data repository path.
        gpio_ip_afs: GPIO IP alternate functions (see ``get_gpio_ip_afs()``).

    Returns:
        Dictionary with all MCU signals.

    Raises:
        FileNotFoundError: If the MCU database folder does not exist.
    """

    mcus_path = data_path / "mcu"
    if not mcus_path.exists():
        raise FileNotFoundError(f"MCU DB folder '{mcus_path}' does not exist")

    results = dict()

    for mcu_file in mcus_path.glob("STM32*.xml"):
        if FAMILY_FILTER.should_skip_model(mcu_file.name):
            continue

        mcu_root = ET.parse(mcu_file).getroot()

        # obtain family, reference and GPIO IP
        family = mcu_root.get("Family").replace("+", "")
        ref = mcu_root.get("RefName")

        gpio_ip_version = next(
            (
                ip.get("Version")
                for ip in mcu_root.findall(NS + "IP")
                if ip.get("Name") == "GPIO"
            ),
            None,
        )

        if not gpio_ip_version:
            logger.error(f"GPIO IP version not specified (mcu: {mcu_file})")
            continue

        if gpio_ip_version not in gpio_ip_afs:
            logger.error(
                f"GPIO IP version {gpio_ip_version} not available (mcu: {mcu_file})"
            )
            continue

        gpio_ip = gpio_ip_afs[gpio_ip_version]

        # create reference entry on its family
        pin_entries = list()
        results.setdefault(family, list()).append({"name": ref, "pins": pin_entries})

        # process all pins
        for pin in mcu_root.findall(NS + "Pin"):
            if pin.get("Type") != "I/O":
                continue

            pin_name = pin.get("Name")
            if not pin_name:
                continue

            # skip duplicate remappable entries in some C0 and G0 files
            if family in ("STM32C0", "STM32G0") and pin_name in ("PA9", "PA10"):
                continue

            # obtain pin port (A, B, ...), number (0, 1, ...) and modifier
            m = re.search(r"^P([A-Z])(\d+)(.*)$", pin_name)
            if not m or pin_name not in gpio_ip:
                continue

            pin_entries.append(
                {
                    "port": m.group(1).lower(),
                    "pin": int(m.group(2)),
                    "mod": m.group(3).lower() if m.group(3) in PIN_MODS else "",
                    "signals": _get_pin_signals(pin, gpio_ip[pin_name], family),
                }
            )

    return results
540
541
def detect_xml_namespace(data_path: Path):
    """Detect the XML namespace used in the pin data files.

    This removes the need to modify this file when using pin data from sources
    other than the official ST repository, which may use a different xmlns.

    The namespace is read from the ``Mcu`` root element of the first
    ``STM32*.xml`` file found in ``<data_path>/mcu``. It is stored in the
    global ``NS`` in ElementTree form, e.g. ``"{http://example.com/mcu}"``.
    If it cannot be determined, ``NS`` is left unchanged.

    Args:
        data_path: STM32 Open Pin Data repository path.
    """
    global NS

    mcus_path = data_path / "mcu"
    try:
        sampled_file = next(mcus_path.glob("STM32*.xml"))
    except StopIteration:
        # No STM32*.xml file found. Log a warning but continue script execution.
        # If this really isn't a pindata folder, something else will panic later on.
        logger.warning(f"No STM32*.xml found in {data_path!s} - XMLNS detection skipped")
        return

    # Let the XML parser resolve the namespace instead of scanning text: it
    # copes with tags split over several lines, single-quoted attributes and
    # trailing characters after ">". Only the root start tag is needed.
    root_tag = None
    try:
        with open(sampled_file, "rb") as fd:
            for _, element in ET.iterparse(fd, events=("start",)):
                root_tag = element.tag
                break
    except ET.ParseError:
        root_tag = None

    # ElementTree reports namespaced tags as "{namespace}Mcu"
    if not root_tag or not (root_tag.startswith("{") and root_tag.endswith("}Mcu")):
        logger.info(f"Could not determine XML namespace from {sampled_file}")
        return

    NS = root_tag[: -len("Mcu")]

    logger.info(f"Using {NS} as XML namespace.")
582
583
def _get_pinctrl_entries(ref, family, selected_config):
    """Match the signals of one MCU reference against the pinctrl configuration.

    Args:
        ref: MCU reference entry, as produced by ``get_mcu_signals()``.
        family: STM32 family, e.g. "STM32F1".
        selected_config: Configuration entries for the family (already
            validated).

    Returns:
        OrderedDict mapping each configuration entry name to its matched pin
        entries. Groups are sorted by name. Entries within a group are sorted
        by the last character of the signal's first "_"-separated token, then
        by port, then by pin. The result is empty if nothing matched.
    """

    def sort_key(entry):
        # e.g. "usart1_tx" -> "1": groups entries by peripheral instance
        return (str(entry["signal"]).split("_")[0][-1], entry["port"], entry["pin"])

    entries = dict()

    for pin in ref["pins"]:
        # process each pin available signal (matched against regex)
        for signal in pin["signals"]:
            for af in selected_config:
                if not re.search(af["match"], signal["name"]):
                    continue

                # Define the signal mode using, by priority order:
                # 1- the config mode (ie: af["mode"]), always present on F1
                # 2- the inferred mode (an alternate function was found)
                if af.get("mode") or family == "STM32F1":
                    mode = af["mode"]
                elif signal["af"] is not None:
                    mode = "alternate"
                else:
                    mode = "analog"

                entries.setdefault(af["name"], list()).append(
                    {
                        "port": pin["port"],
                        "pin": pin["pin"],
                        "mod": pin["mod"],
                        "signal": signal["name"].lower().replace("-", "_"),
                        "af": signal["af"],
                        "mode": mode,
                        "drive": af.get("drive"),
                        "bias": af.get("bias"),
                        "slew-rate": af.get("slew-rate"),
                        "variant": af.get("variant"),
                    }
                )

    sorted_entries = OrderedDict()
    for group in sorted(entries):
        sorted_entries[group] = sorted(entries[group], key=sort_key)

    return sorted_entries


def _get_data_commit(data_path):
    """Return the git commit of the pin data repository.

    Args:
        data_path: STM32 Open Pin Data repository path.

    Returns:
        The ``HEAD`` commit hash, or "<unknown commit>" if it cannot be
        obtained (not a git checkout, or git not installed).
    """

    try:
        commit_raw = check_output(
            ["git", "rev-parse", "HEAD"], cwd=data_path, stderr=STDOUT
        )
    except (CalledProcessError, OSError):
        # OSError also covers a missing git executable (FileNotFoundError)
        return "<unknown commit>"

    return commit_raw.decode("utf-8").strip()


def main(data_path, output):
    """Entry point.

    Generates one ``<ref>-pinctrl.dtsi`` file per MCU reference under
    ``<output>/st/<family>/``, plus ``<output>/README.rst``. When no family
    filter is active, the output directory is erased first.

    Args:
        data_path: STM32 Open Pin Data repository path.
        output: Output directory.

    Raises:
        ValueError: If a configuration entry is invalid.
    """

    # NOTE(review): yaml.Loader can build arbitrary Python objects; acceptable
    # for the in-repo config files, but never point these at untrusted input.
    with open(CONFIG_FILE) as f:
        config = yaml.load(f, Loader=yaml.Loader)

    with open(CONFIG_F1_FILE) as f:
        config_f1 = yaml.load(f, Loader=yaml.Loader)

    env = Environment(
        trim_blocks=True, lstrip_blocks=True, loader=FileSystemLoader(SCRIPT_DIR)
    )
    env.filters["format_mode"] = format_mode
    env.filters["format_mode_f1"] = format_mode_f1
    env.filters["format_remap"] = format_remap
    env.filters["format_remap_name"] = format_remap_name
    pinctrl_template = env.get_template(PINCTRL_TEMPLATE)
    readme_template = env.get_template(README_TEMPLATE)

    detect_xml_namespace(data_path)

    gpio_ip_afs = get_gpio_ip_afs(data_path)
    mcu_signals = get_mcu_signals(data_path, gpio_ip_afs)

    # erase output if we're about to generate for all families
    if output.exists() and not FAMILY_FILTER.is_active():
        shutil.rmtree(output)
    # the README is written even if no family gets generated
    output.mkdir(parents=True, exist_ok=True)

    for family, refs in mcu_signals.items():
        # obtain family pinctrl address
        pinctrl_addr = PINCTRL_ADDRESSES.get(family.lower())
        if pinctrl_addr is None:
            logger.warning(f"Skipping unsupported family {family}.")
            continue
        logger.info(f"Processing family {family}...")

        selected_config = config_f1 if family == "STM32F1" else config
        # validate once per family rather than once per signal
        for af in selected_config:
            validate_config_entry(af, family)

        # create directory for each family
        family_dir = output / "st" / family.lower()[5:]
        family_dir.mkdir(parents=True, exist_ok=True)

        # process each reference
        for ref in refs:
            entries = _get_pinctrl_entries(ref, family, selected_config)
            if not entries:
                continue

            # write pinctrl file
            pinctrl_filename = f"{ref['name'].lower()}-pinctrl.dtsi"
            try:
                rendered = pinctrl_template.render(
                    family=family, pinctrl_addr=pinctrl_addr, entries=entries
                )
            except Exception as e:
                # best effort: one broken reference must not abort the run
                logger.error(
                    f"Skipping '{pinctrl_filename}' (rendering failed: {e})"
                )
                continue

            with open(family_dir / pinctrl_filename, "w") as f:
                f.write(rendered)

    # write readme file
    commit = _get_data_commit(data_path)
    with open(output / "README.rst", "w") as f:
        f.write(readme_template.render(commit=commit))
721
722
if __name__ == "__main__":
    # command line interface
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-p",
        "--data-path",
        type=Path,
        required=True,
        help="Path to STM32 Open Pin Data repository",
    )
    cli.add_argument(
        "-o", "--output", type=Path, default=REPO_ROOT / "dts", help="Output directory"
    )
    cli.add_argument("-v", "--verbose", action="store_true", help="Make script verbose")

    # family selection: either an allow-list or a deny-list, never both
    family_options = cli.add_mutually_exclusive_group()
    for short_flag, long_flag, verb in (
        ("-f", "--only-family", "process only"),
        ("-nf", "--not-family", "don't process"),
    ):
        family_options.add_argument(
            short_flag,
            long_flag,
            type=str,
            action="append",
            help=f"{verb} specified STM32 family (can be specified multiple times)",
        )

    cli_args = cli.parse_args()

    log_level = logging.INFO if cli_args.verbose else logging.WARNING
    logger.setLevel(log_level)
    logger.addHandler(logging.StreamHandler())

    FAMILY_FILTER.set_filters(cli_args.only_family, cli_args.not_family)

    main(cli_args.data_path, cli_args.output)
770