1#!/usr/bin/env python3
2#
3# Copyright (c) 2020 Nordic Semiconductor ASA
4#
5# SPDX-License-Identifier: Apache-2.0
6#
7
8from regex import compile, S, M
9from pprint import pformat, pprint
10from os import path, linesep, makedirs
11from collections import defaultdict, namedtuple
12from collections.abc import Hashable
13from typing import NamedTuple
14from argparse import ArgumentParser, ArgumentTypeError, RawDescriptionHelpFormatter, FileType
15from datetime import datetime
16from copy import copy
17from itertools import tee, chain
18from cbor2 import (loads, dumps, CBORTag, load, CBORDecodeValueError, CBORDecodeEOF, undefined,
19                   CBORSimpleValue)
20from yaml import safe_load as yaml_load, dump as yaml_dump
21from json import loads as json_load, dumps as json_dump
22from io import BytesIO
23from subprocess import Popen, PIPE
24from pathlib import Path, PurePath, PurePosixPath
25from shutil import copyfile
26import sys
27from site import USER_BASE
28from textwrap import wrap, indent
29from importlib.metadata import version
30
# Cache of compiled regex patterns, filled by getrp(). Keyed by the pattern string, or by a
# (pattern, flags) tuple when flags are given.
regex_cache = {}
# One indentation unit (a single tab).
indentation = "\t"
# A newline followed by one indentation unit.
newl_ind = "\n" + indentation

# Directory containing this file.
SCRIPT_PATH = Path(__file__).absolute().parent
# Parent directory of SCRIPT_PATH.
PACKAGE_PATH = Path(__file__).absolute().parents[1]
PRELUDE_PATH = SCRIPT_PATH / "prelude.cddl"
VERSION_PATH = SCRIPT_PATH / "VERSION"
C_SRC_PATH = PACKAGE_PATH / "src"
C_INCLUDE_PATH = PACKAGE_PATH / "include"

# The version string is read from the VERSION file when this module is imported.
__version__ = VERSION_PATH.read_text(encoding="utf-8").strip()

# Largest values of 8, 16, 32, and 64 bit unsigned integers.
UINT8_MAX = 0xFF
UINT16_MAX = 0xFFFF
UINT32_MAX = 0xFFFFFFFF
UINT64_MAX = 0xFFFFFFFFFFFFFFFF

# Largest values of 8, 16, 32, and 64 bit signed integers.
INT8_MAX = 0x7F
INT16_MAX = 0x7FFF
INT32_MAX = 0x7FFFFFFF
INT64_MAX = 0x7FFFFFFFFFFFFFFF

# Smallest values of 8, 16, 32, and 64 bit signed integers.
INT8_MIN = -0x80
INT16_MIN = -0x8000
INT32_MIN = -0x80000000
INT64_MIN = -0x8000000000000000
58
59
def getrp(pattern, flags=0):
    """Return the compiled form of pattern, compiling it only the first time it is requested.

    Compiled patterns are kept in regex_cache. The cache key is the pattern itself when no
    flags are given, and the tuple (pattern, flags) otherwise.
    """
    cache_key = (pattern, flags) if flags else pattern
    compiled = regex_cache.get(cache_key)
    if compiled is None:
        compiled = compile(pattern, flags)
        regex_cache[cache_key] = compiled
    return compiled
66
67
def sizeof(num):
    """Return the size in bytes of the "additional" field if num is encoded as int.

    Numbers up to 23 need no additional field (0). Larger numbers need 1, 2, 4, or 8 bytes
    depending on which unsigned integer maximum they fit within.

    Raises ValueError if num is larger than UINT64_MAX.
    """
    if num <= 23:
        return 0
    size_limits = (
        (UINT8_MAX, 1),
        (UINT16_MAX, 2),
        (UINT32_MAX, 4),
        (UINT64_MAX, 8),
    )
    for limit, num_bytes in size_limits:
        if num <= limit:
            return num_bytes
    raise ValueError("Number too large (more than 64 bits).")
82
83
def verbose_print(verbose_flag, *things):
    """Pass things on to print(), but only when verbose_flag is truthy."""
    if not verbose_flag:
        return
    print(*things)
88
89
def verbose_pprint(verbose_flag, *things):
    """Pretty print each of the things, but only if verbose_flag is truthy.

    Every object is handed to pprint() on its own: pprint() formats only its first positional
    argument and interprets further positional arguments as its stream, indent, etc.
    parameters, so they must not be forwarded together.
    """
    if verbose_flag:
        for thing in things:
            pprint(thing)
94
95
# The id most recently returned by counter(); 0 right after a reset.
global_counter = 0


def counter(reset=False):
    """Return a unique id, or restart the numbering.

    Without reset, the global counter is incremented and its new value returned. With reset,
    the global counter is set back to 0 and 0 is returned.
    """
    global global_counter
    global_counter = 0 if reset else global_counter + 1
    return global_counter
107
108
def list_replace_if_not_null(lst, i, r):
    """Replace element i of a list or tuple with r, unless that element is "NULL".

    If lst[i] is "NULL", lst is returned untouched. A tuple yields a new tuple with the
    replacement. A list is modified in place and a copy of it is returned.
    """
    if lst[i] == "NULL":
        return lst
    if isinstance(lst, tuple):
        items = list(lst)
        items[i] = r
        return tuple(items)
    assert isinstance(lst, list)
    lst[i] = r
    return list(lst)
121
122
def val_or_null(value, var_name):
    """Return a code snippet that assigns to, and then returns a pointer to, a variable.

    The snippet assigns value (formatted as an integer) to the variable var_name and evaluates
    to a pointer to that variable. If value is None, the snippet is just "NULL".
    """
    if value is None:
        return "NULL"
    return "(%s = %d, &%s)" % (var_name, value, var_name)
130
131
def tmp_str_or_null(value):
    """Return a code snippet that fills in tmp_str and evaluates to a pointer to it.

    For a string value, tmp_str.value points to the string literal and tmp_str.len is the
    literal's length (sizeof minus the terminator). If value is None, tmp_str.value is NULL
    and tmp_str.len is 0, so that no length is ever paired with a NULL pointer.
    """
    if value is None:
        return "(tmp_str.value = (uint8_t *)NULL, tmp_str.len = 0, &tmp_str)"
    literal = f'"{value}"'
    return (f"(tmp_str.value = (uint8_t *){literal}, "
            f"tmp_str.len = sizeof({literal}) - 1, &tmp_str)")
137
138
def min_bool_or_null(value):
    """Return a code snippet that evaluates to a pointer to a bool, or "NULL".

    The bool is a compound literal holding int(value), i.e. 0 or 1 for Python bools.
    If value is None, "NULL" is returned, in line with val_or_null().
    """
    if value is None:
        return "NULL"
    return f"(&(bool){{{int(value)}}})"
142
143
def deref_if_not_null(access):
    """Return access prefixed with "&", or access itself if it is the string "NULL"."""
    if access == "NULL":
        return access
    return "&" + access
146
147
def xcode_args(res, *sargs):
    """Return an argument list for a function call to an encoder/decoder function.

    The list always starts with "state". Without extra arguments, res follows wrapped in
    parentheses. With extra arguments, res follows as "&(res)" and then the first two extra
    arguments. In both cases, a res of "NULL" is passed on as is.
    """
    if sargs:
        res_arg = res if res == "NULL" else "&(%s)" % res
        return "state, %s, %s, %s" % (res_arg, sargs[0], sargs[1])
    res_arg = res if res == "NULL" else "(%s)" % res
    return "state, %s" % res_arg
156
157
def xcode_statement(func, *sargs, **kwargs):
    """Return the code that calls an encoder/decoder function with the given arguments.

    The arguments are formatted by xcode_args(). If func is None, there is nothing to call
    and the statement is just "1".
    """
    if func is None:
        return "1"
    call_args = xcode_args(*sargs, **kwargs)
    return "(%s(%s))" % (func, call_args)
163
164
def add_semicolon(decl):
    """Make sure the last string in the list decl ends with ";" and return decl.

    The list is modified in place. An empty list is returned unchanged.
    """
    if len(decl) != 0:
        last_line = decl[-1]
        if last_line[-1] != ";":
            decl[-1] = last_line + ";"
    return decl
169
170
def struct_ptr_name(mode):
    """Return the name of the struct argument: "result" for mode "decode", else "input"."""
    if mode == "decode":
        return "result"
    return "input"
174
175
def ternary_if_chain(access, names, xcode_strings):
    """Return a chain of nested ternary expressions comparing access to each of the names.

    For the first name that access equals, the expression evaluates to the xcode string at the
    same position. Each further name is handled by a nested expression on a new, indented
    line. If no name matches, the expression evaluates to "false".
    """
    first_name, first_xcode = names[0], xcode_strings[0]
    if len(names) > 1:
        fallback = ternary_if_chain(access, names[1:], xcode_strings[1:])
    else:
        fallback = "false"
    return "((%s == %s) ? %s%s: %s)" % (access, first_name, first_xcode, newl_ind, fallback)
183
184
# Integer limits and the names of the constants that val_to_str() prints in their place.
val_conversions = {
    (1 << 64) - 1: "UINT64_MAX",
    (1 << 63) - 1: "INT64_MAX",
    (1 << 32) - 1: "UINT32_MAX",
    (1 << 31) - 1: "INT32_MAX",
    (1 << 16) - 1: "UINT16_MAX",
    (1 << 15) - 1: "INT16_MAX",
    (1 << 8) - 1: "UINT8_MAX",
    (1 << 7) - 1: "INT8_MAX",
    -(1 << 63): "INT64_MIN",
    -(1 << 31): "INT32_MIN",
    -(1 << 15): "INT16_MIN",
    -(1 << 7): "INT8_MIN",
}


def val_to_str(val):
    """Return val as a string.

    Bools become "true"/"false", values found in val_conversions become the name of the
    corresponding constant, and everything else is converted with str().
    """
    if isinstance(val, bool):
        return "true" if val else "false"
    if isinstance(val, Hashable) and val in val_conversions:
        return val_conversions[val]
    return str(val)
207
208
209class CddlParser:
210    """Class for parsing CDDL.
211
212    One instance represents one CBOR data item, with a few caveats:
213    - For repeated data, one instance represents all repetitions.
214    - For "OTHER" types, one instance points to another type definition.
215    - For "GROUP" and "UNION" types, there is no separate data item for the instance.
216    """
217    def __init__(self, default_max_qty, my_types, my_control_groups, base_name=None,
218                 short_names=False, base_stem=''):
219        self.id_prefix = "temp_" + str(counter())
220        self.id_num = None  # Unique ID number. Only populated if needed.
221        # The value of the data item. Has different meaning for different
222        # types.
223        self.value = None
224        self.max_value = None  # Maximum value. Only used for numbers and bools.
225        self.min_value = None  # Minimum value. Only used for numbers and bools.
226        # The readable label associated with the element in the CDDL.
227        self.label = None
228        self.min_qty = 1  # The minimum number of times this element is repeated.
229        self.max_qty = 1  # The maximum number of times this element is repeated.
230        # The size of the element. Only used for integers, byte strings, and
231        # text strings.
232        self.size = None
233        self.min_size = None  # Minimum size.
234        self.max_size = None  # Maximum size.
235        # Key element. Only for children of "MAP" elements. self.key is of the
236        # same class as self.
237        self.key = None
238        # The element specified via.cbor or.cborseq(only for byte
239        # strings).self.cbor is of the same class as self.
240        self.cbor = None
241        # Any tags (type 6) to precede the element.
242        self.tags = []
243        # The CDDL string used to determine the min_qty and max_qty. Not used after
244        # min_qty and max_qty are determined.
245        self.quantifier = None
246        # Sockets are types starting with "$" or "$$". Do not fail if these aren't defined.
247        self.is_socket = False
248        # If the type has a ".bits <group_name>", this will contain <group_name> which can be looked
249        # up in my_control_groups.
250        self.bits = None
251        # The "type" of the element. This follows the CBOR types loosely, but are more related to
252        # CDDL concepts. The possible types are "INT", "UINT", "NINT", "FLOAT", "BSTR", "TSTR",
253        # "BOOL", "NIL", "UNDEF", "LIST", "MAP","GROUP", "UNION" and "OTHER". "OTHER" represents a
254        # CDDL type defined with '='.
255        self.type = None
256        self.match_str = ""
257        self.errors = list()
258
259        self.my_types = my_types
260        self.my_control_groups = my_control_groups
261        self.default_max_qty = default_max_qty  # args.default_max_qty
262        self.base_name = base_name  # Used as default for self.get_base_name()
263        # Stem which can be used when generating an id.
264        self.base_stem = base_stem.replace("-", "_")
265        self.short_names = short_names
266
267        if type(self) not in type(self).cddl_regexes:
268            self.cddl_regexes_init()
269
270    @classmethod
271    def from_cddl(cddl_class, cddl_string, default_max_qty, *args, **kwargs):
272        my_types = dict()
273
274        type_strings = cddl_class.get_types(cddl_string)
275        # Separate type_strings as keys in two dicts, one dict for strings that start with &( which
276        # are special control operators for .bits, and one dict for all the regular types.
277        my_types = \
278            {my_type: None for my_type, val in type_strings.items() if not val.startswith("&(")}
279        my_control_groups = \
280            {my_cg: None for my_cg, val in type_strings.items() if val.startswith("&(")}
281
282        # Parse the definitions, replacing the each string with a
283        # CodeGenerator instance.
284        for my_type, cddl_string in type_strings.items():
285            parsed = cddl_class(*args, default_max_qty, my_types, my_control_groups, **kwargs,
286                                base_stem=my_type)
287            parsed.get_value(cddl_string.replace("\n", " ").lstrip("&"))
288            parsed = parsed.flatten()[0]
289            if my_type in my_types:
290                my_types[my_type] = parsed
291            elif my_type in my_control_groups:
292                my_control_groups[my_type] = parsed
293
294        counter(True)
295
296        # post_validate all the definitions.
297        for my_type in my_types:
298            my_types[my_type].set_id_prefix()
299            my_types[my_type].post_validate()
300            my_types[my_type].set_base_names()
301        for my_control_group in my_control_groups:
302            my_control_groups[my_control_group].set_id_prefix()
303            my_control_groups[my_control_group].post_validate_control_group()
304
305        return CddlTypes(my_types, my_control_groups)
306
307    @staticmethod
308    def strip_comments(instr):
309        """Strip CDDL comments (';') from the string."""
310        return getrp(r"\;.*?(\n|$)").sub('', instr)
311
312    @staticmethod
313    def resolve_backslashes(instr):
314        """Replace escaped newlines with spaces."""
315        return getrp(r"\\\n").sub(" ", instr)
316
317    @classmethod
318    def get_types(cls, cddl_string):
319        """Returns a dict containing multiple typename=>string"""
320        instr = cls.strip_comments(cddl_string)
321        instr = cls.resolve_backslashes(instr)
322        type_regex = \
323            r"(\s*?\$?\$?([\w-]+)\s*(\/{0,2})=\s*(.*?)(?=(\Z|\s*\$?\$?[\w-]+\s*\/{0,2}=(?!\>))))"
324        result = defaultdict(lambda: "")
325        types = [
326            (key, value, slashes)
327            for (_1, key, slashes, value, _2) in getrp(type_regex, S | M).findall(instr)]
328        for key, value, slashes in types:
329            if slashes:
330                result[key] += slashes
331                result[key] += value
332                result[key] = result[key].lstrip(slashes)  # strip from front
333            else:
334                if key in result:
335                    raise ValueError(f"Duplicate CDDL type found: {key}")
336                result[key] = value
337        return dict(result)
338
    # A backslash followed by a quotation mark. generate_base_name() removes this sequence
    # from text string values before using them in a name.
    backslash_quotation_mark = r'\"'
340
341    def generate_base_name(self):
342        """Generate a (hopefully) unique and descriptive name"""
343        byte_multi = (8 if self.type in ["INT", "UINT", "NINT", "FLOAT"] else 1)
344
345        # The first non-None entry is used:
346        raw_name = ((
347            # The label is the default if present:
348            self.label
349            # Name a key/value pair by its key type or string value:
350            or (self.key.value if self.key and self.key.type in ["TSTR", "OTHER"] else None)
351            # Name a string by its expected value:
352            or (f"{self.value.replace(self.backslash_quotation_mark, '')}_{self.type.lower()}"
353                if self.type == "TSTR" and self.value is not None else None)
354            # Name an integer by its expected value:
355            or (f"{self.type.lower()}{abs(self.value)}"
356                if self.type in ["INT", "UINT", "NINT"] and self.value is not None else None)
357            # Name a type by its type name
358            or (next((key for key, value in self.my_types.items() if value == self), None))
359            # Name a control group by its name
360            or (next((key for key, value in self.my_control_groups.items() if value == self), None))
361            # Name an instance by its type:
362            or (self.value + "_m" if self.type == "OTHER" else None)
363            # Name a list by its first element:
364            or (self.value[0].get_base_name() + "_l"
365                if self.type in ["LIST", "GROUP"] and self.value else None)
366            # Name a cbor-encoded bstr by its expected cbor contents:
367            or ((self.cbor.value + "_bstr")
368                if self.cbor and self.cbor.type in ["TSTR", "OTHER"] else None)
369            # Name a key value pair by its key (regardless of the key type)
370            or ((self.key.generate_base_name() + self.type.lower()) if self.key else None)
371            # Name an element by its minimum/maximum "size" (if the min == the max)
372            or (f"{self.type.lower()}{self.min_size * byte_multi}"
373                if (self.min_size is not None) and self.min_size == self.max_size else None)
374            # Name an element by its minimum/maximum "size" (if the min != the max)
375            or (f"{self.type.lower()}{self.min_size * byte_multi}-{self.max_size * byte_multi}"
376                if (self.min_size is not None) and (self.max_size is not None) else None)
377            # Name an element by its type.
378            or self.type.lower()).replace("-", "_"))
379
380        # Make the name compatible with C variable names
381        # (don't start with a digit, don't use accented letters or symbols other than '_')
382        name_regex = getrp(r'[a-zA-Z_][a-zA-Z\d_]*')
383        if name_regex.fullmatch(raw_name) is None:
384            latinized_name = getrp(r'[^a-zA-Z\d_]').sub("", raw_name)
385            if name_regex.fullmatch(latinized_name) is None:
386                # Add '_' if name starts with a digit or is empty after removing accented chars.
387                latinized_name = "_" + latinized_name
388            assert name_regex.fullmatch(latinized_name) is not None, \
389                f"Couldn't make '{raw_name}' valid. '{latinized_name}' is invalid."
390            return latinized_name
391        return raw_name
392
393    def get_base_name(self):
394        """Base name used for functions, variables, and typedefs."""
395        if not self.base_name:
396            self.set_base_name(self.generate_base_name())
397        return self.base_name
398
399    def set_base_name(self, base_name):
400        """Set an explicit base name for this element."""
401        self.base_name = base_name.replace("-", "_")
402
403    def set_base_names(self):
404        """Recursively set the base names of this element's children, keys, and cbor elements."""
405        if self.cbor:
406            self.cbor.set_base_name(self.var_name().strip("_") + "_cbor")
407        if self.key:
408            self.key.set_base_name(self.var_name().strip("_") + "_key")
409
410        if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
411            for child in self.value:
412                child.set_base_names()
413        if self.cbor:
414            self.cbor.set_base_names()
415        if self.key:
416            self.key.set_base_names()
417
418    def id(self, with_prefix=True):
419        """Add uniqueness to the base name."""
420        raw_name = self.get_base_name()
421        if not with_prefix and self.short_names:
422            return raw_name
423        if (self.id_prefix
424                and (f"{self.id_prefix}_" not in raw_name)
425                and (self.id_prefix != raw_name.strip("_"))):
426            return f"{self.id_prefix}_{raw_name}"
427        if (self.base_stem
428                and (f"{self.base_stem}_" not in raw_name)
429                and (self.base_stem != raw_name.strip("_"))):
430            return f"{self.base_stem}_{raw_name}"
431        return raw_name
432
433    def init_args(self):
434        """Return the args that should be used to initialize a new instance of this class."""
435        return (self.default_max_qty,)
436
437    def init_kwargs(self):
438        """Return the kwargs that should be used to initialize a new instance of this class."""
439        return {
440            "my_types": self.my_types, "my_control_groups": self.my_control_groups,
441            "short_names": self.short_names}
442
443    def set_id_prefix(self, id_prefix=''):
444        self.id_prefix = id_prefix
445        if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
446            for child in self.value:
447                if child.single_func_impl_condition():
448                    child.set_id_prefix(self.generate_base_name())
449                else:
450                    child.set_id_prefix(self.child_base_id())
451        if self.cbor:
452            self.cbor.set_id_prefix(self.child_base_id())
453        if self.key:
454            self.key.set_id_prefix(self.child_base_id())
455
456    def child_base_id(self):
457        """Id to pass to children for them to use as basis for their id/base name."""
458        return self.id()
459
460    def mrepr(self, newline):
461        """Human readable representation."""
462        reprstr = ''
463        if self.quantifier:
464            reprstr += self.quantifier
465        if self.label:
466            reprstr += self.label + ':'
467        for tag in self.tags:
468            reprstr += f"#6.{tag}"
469        if self.key:
470            reprstr += repr(self.key) + " => "
471        if self.is_unambiguous():
472            reprstr += '/'
473        if self.is_unambiguous_repeated():
474            reprstr += '/'
475        reprstr += self.type
476        if self.size:
477            reprstr += '(%d)' % self.size
478        if newline:
479            reprstr += '\n'
480        if self.value:
481            reprstr += pformat(self.value, indent=4, width=1)
482        if self.cbor:
483            reprstr += " cbor: " + repr(self.cbor)
484        return reprstr.replace('\n', '\n    ')
485
486    def _flatten(self):
487        """Recursively flatten children, key, and cbor elements."""
488        new_value = []
489        if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
490            for child in self.value:
491                new_value.extend(
492                    child.flatten(allow_multi=self.type != "UNION"))
493            self.value = new_value
494        if self.key:
495            self.key = self.key.flatten()[0]
496        if self.cbor:
497            self.cbor = self.cbor.flatten()[0]
498
499    def flatten(self, allow_multi=False):
500        """Remove unneccessary abstractions, like single-element groups or unions."""
501        self._flatten()
502        if self.type == "OTHER" and self.is_socket and self.value not in self.my_types:
503            return []
504        if self.type in ["GROUP", "UNION"]\
505                and (len(self.value) == 1)\
506                and (not (self.key and self.value[0].key)):
507            self.value[0].min_qty *= self.min_qty
508            self.value[0].max_qty *= self.max_qty
509            if not self.value[0].label:
510                self.value[0].label = self.label
511            if not self.value[0].key:
512                self.value[0].key = self.key
513            self.value[0].tags.extend(self.tags)
514            return self.value
515        elif allow_multi and self.type in ["GROUP"] and self.min_qty == 1 and self.max_qty == 1:
516            return self.value
517        else:
518            return [self]
519
520    def set_min_value(self, min_value):
521        self.min_value = min_value
522
523    def set_max_value(self, max_value):
524        self.max_value = max_value
525
526    def type_and_value(self, new_type, value_generator):
527        """Set the self.type and self.value of this element."""
528        if self.type is not None:
529            raise TypeError(
530                "Cannot have two values: %s, %s" %
531                (self.type, new_type))
532        if new_type is None:
533            raise TypeError("Cannot set None as type")
534        if new_type == "UNION" and self.value is not None:
535            raise ValueError("Did not expect multiple parsed values for union")
536
537        self.type = new_type
538        self.set_value(value_generator)
539
540    def set_value(self, value_generator):
541        """Set the value of this element.
542
543        value_generator must be a function that returns the value of the element."""
544        value = value_generator()
545        self.value = value
546
547        if self.type == "OTHER" and self.value.startswith("$"):
548            self.value = self.value.lstrip("$")
549            self.is_socket = True
550
551        if self.type in ["BSTR", "TSTR"]:
552            if value is not None:
553                self.set_size(len(value))
554        if self.type in ["UINT", "NINT"]:
555            if value is not None:
556                self.size = sizeof(value)
557                self.set_min_value(value)
558                self.set_max_value(value)
559        if self.type == "NINT":
560            self.max_value = -1
561
562    def type_and_range(self, new_type, min_val, max_val, inc_end=True):
563        """Set the self.type and self.minValue and self.max_value (or self.min_size and
564        self.max_size depending on the type) of this element. For use during CDDL parsing.
565        """
566        if not inc_end:
567            max_val -= 1
568        if new_type not in ["INT", "UINT", "NINT"]:
569            raise TypeError(
570                "Only integers (not %s) can have range" %
571                (new_type,))
572        if min_val > max_val:
573            raise TypeError(
574                "Range has larger minimum than maximum (min %d, max %d)" %
575                (min_val, max_val))
576        if min_val == max_val:
577            return self.type_and_value(new_type, min_val)
578        self.type = new_type
579        self.set_min_value(min_val)
580        self.set_max_value(max_val)
581        if new_type in "UINT":
582            self.set_size_range(sizeof(min_val), sizeof(max_val))
583        if new_type == "NINT":
584            self.set_size_range(sizeof(abs(max_val)), sizeof(abs(min_val)))
585        if new_type == "INT":
586            self.set_size_range(None, max(sizeof(abs(max_val)), sizeof(abs(min_val))))
587
588    def type_value_size(self, new_type, value, size):
589        """Set the self.value and self.size of this element."""
590        self.type_and_value(new_type, value)
591        self.set_size(size)
592
593    def type_value_size_range(self, new_type, value, min_size, max_size):
594        """Set the self.value and self.min_size and self.max_size of this element."""
595        self.type_and_value(new_type, value)
596        self.set_size_range(min_size, max_size)
597
598    def set_label(self, label):
599        """Set the self.label of this element. For use during CDDL parsing."""
600        if self.type is not None:
601            raise TypeError("Cannot have label after value: " + label)
602        self.label = label
603
604    def set_quantifier(self, quantifier):
605        """Set the self.quantifier, self.min_qty, and self.max_qty of this element"""
606        if self.type is not None:
607            raise TypeError(
608                "Cannot have quantifier after value: " + quantifier)
609
610        quantifier_mapping = [
611            (r"\?", lambda mo: (0, 1)),
612            (r"\*", lambda mo: (0, None)),
613            (r"\+", lambda mo: (1, None)),
614            (r"(.*?)\*\*?(.*)",
615                lambda mo: (int(mo.groups()[0] or "0", 0), int(mo.groups()[1] or "0", 0) or None)),
616        ]
617
618        self.quantifier = quantifier
619        for (reg, handler) in quantifier_mapping:
620            match_obj = getrp(reg).match(quantifier)
621            if match_obj:
622                (self.min_qty, self.max_qty) = handler(match_obj)
623                if self.max_qty is None:
624                    self.max_qty = self.default_max_qty
625                return
626        raise ValueError("invalid quantifier: %s" % quantifier)
627
628    def set_size(self, size):
629        """Set the self.size of this element.
630
631        This will also set the self.minValue and self.max_value of UINT types.
632        """
633        if self.type is None:
634            raise TypeError("Cannot have size before value: " + str(size))
635        elif self.type in ["INT", "UINT", "NINT"]:
636            value = 256**size
637            if self.type == "INT":
638                self.max_value = int((value >> 1) - 1)
639            if self.type == "UINT":
640                self.max_value = int(value - 1)
641            if self.type in ["INT", "NINT"]:
642                self.min_value = int(-1 * (value >> 1))
643        elif self.type in ["BSTR", "TSTR", "FLOAT"]:
644            self.set_size_range(size, size)
645        else:
646            raise TypeError(".size cannot be applied to %s" % self.type)
647
648    def set_size_range(self, min_size, max_size_in, inc_end=True):
649        """Set the self.minValue and self.max_value or self.min_size and self.max_size of this
650        element based on what values can be contained within an integer of a certain size.
651        """
652        max_size = max_size_in if inc_end else max_size_in - 1
653
654        if (min_size and min_size < 0 or max_size and max_size < 0) \
655           or (None not in [min_size, max_size] and min_size > max_size):
656            raise TypeError(
657                "Invalid size range (min %d, max %d)" %
658                (min_size, max_size))
659
660        self.set_min_size(min_size)
661        self.set_max_size(max_size)
662
663    def set_min_size(self, min_size):
664        """Set self.min_size, and self.minValue if type is UINT."""
665        if self.type == "UINT":
666            self.minValue = 256**min(0, abs(min_size - 1)) if min_size is not None else None
667        self.min_size = min_size if min_size is not None else None
668
669    def set_max_size(self, max_size):
670        """Set self.max_size, and self.max_value if type is UINT."""
671        if self.type == "UINT" and max_size and self.max_value is None:
672            if max_size > 8:
673                raise TypeError(
674                    "Size too large for integer. size %d" %
675                    max_size)
676            self.max_value = 256**max_size - 1
677        self.max_size = max_size
678
679    def set_cbor(self, cbor, cborseq):
680        """Set the self.cbor of this element. For use during CDDL parsing."""
681        if self.type != "BSTR":
682            raise TypeError(
683                "%s must be used with bstr." %
684                (".cborseq" if cborseq else ".cbor",))
685        self.cbor = cbor
686        if cborseq:
687            self.cbor.max_qty = self.default_max_qty
688
689    def set_bits(self, bits):
690        """Set the self.bits of this element. For use during CDDL parsing."""
691        if self.type != "UINT":
692            raise TypeError(".bits must be used with bstr.")
693        self.bits = bits
694
695    def set_key(self, key):
696        """Set the self.key of this element. For use during CDDL parsing."""
697        if self.key is not None:
698            raise TypeError("Cannot have two keys: " + key)
699        if key.type == "GROUP":
700            raise TypeError("A key cannot be a group because it might represent more than 1 type.")
701        self.key = key
702
703    def set_key_or_label(self, key_or_label):
704        """Set the self.label OR self.key of this element.
705
706        In the CDDL "foo: bar", foo can be either a label or a key depending on whether it is in a
707        map. This code uses a slightly different method for choosing between label and key.
708        If the string is recognized as a type, it is treated as a key. For use during CDDL parsing.
709        """
710        if key_or_label in self.my_types:
711            self.set_key(self.parse(key_or_label)[0])
712            assert self.key.type == "OTHER", "This should only be able to produce an OTHER key."
713            if self.label is None:
714                self.set_label(key_or_label)
715        else:
716            self.set_label(key_or_label)
717
718    def add_tag(self, tag):
719        self.tags.append(int(tag))
720
721    def union_add_value(self, value, doubleslash=False):
722        """Append to the self.value of this element.
723
724        Used with the "UNION" type, which has a python list as self.value. The list represents the
725        "children" of the type. For use during CDDL parsing.
726        """
727        if self.type != "UNION":
728            convert_val = copy(self)
729            self.__init__(*self.init_args(), **self.init_kwargs())
730            self.type_and_value("UNION", lambda: [convert_val])
731
732            self.base_name = convert_val.base_name
733            convert_val.base_name = None
734            self.base_stem = convert_val.base_stem
735
736            if not doubleslash:
737                self.label = convert_val.label
738                self.key = convert_val.key
739                self.quantifier = convert_val.quantifier
740                self.max_qty = convert_val.max_qty
741                self.min_qty = convert_val.min_qty
742
743                convert_val.label = None
744                convert_val.key = None
745                convert_val.quantifier = None
746                convert_val.max_qty = 1
747                convert_val.min_qty = 1
748        self.value.append(value)
749
750    def convert_to_key(self):
751        """The current element is the key, so copy it to a new element and set the key to the new"""
752        convert_val = copy(self)
753        self.__init__(*self.init_args(), **self.init_kwargs())
754        self.set_key(convert_val)
755
756        self.label = convert_val.label
757        self.quantifier = convert_val.quantifier
758        self.max_qty = convert_val.max_qty
759        self.min_qty = convert_val.min_qty
760        self.base_name = convert_val.base_name
761        self.base_stem = convert_val.base_stem
762
763        convert_val.label = None
764        convert_val.quantifier = None
765        convert_val.max_qty = 1
766        convert_val.min_qty = 1
767        convert_val.base_name = None
768
769    # A dict with lists of regexes and their corresponding handlers.
770    # This is a dict in case multiple inheritors of CddlParser are used at once, in which case
771    # they may have slightly different handlers.
772    cddl_regexes = dict()
773
    def cddl_regexes_init(self):
        """Initialize the cddl_regexes dict

        Builds the ordered list of (regex, handler) pairs for type(self) and stores it in the
        class-level cddl_regexes dict, keyed by type(self). get_value() tries the regexes in list
        order against the start of the remaining input and calls the handler of the first one
        that matches, as handler(element, matched_string), where matched_string is the "item"
        group if the regex has one, else the whole match.
        """
        # Building blocks for numbers: hex, octal, binary or decimal literals.
        match_uint = r"(0x[0-9a-fA-F]+|0o[0-7]+|0b[01]+|\d+)"
        match_int = r"(-?" + match_uint + ")"
        match_nint = r"(-" + match_uint + ")"

        self_type = type(self)

        # The "range_types" match the contents of brackets i.e. (), [], and {},
        # and strings, i.e. ' or "
        range_types = [
            (r'(?P<bracket>\[(?P<item>(?>[^[\]]+|(?&bracket))*)\])',
             lambda m_self, list_str: m_self.type_and_value(
                 "LIST", lambda: m_self.parse(list_str))),
            (r'(?P<paren>\((?P<item>(?>[^\(\)]+|(?&paren))*)\))',
             lambda m_self, group_str: m_self.type_and_value(
                 "GROUP", lambda: m_self.parse(group_str))),
            (r'(?P<curly>{(?P<item>(?>[^{}]+|(?&curly))*)})',
             lambda m_self, map_str: m_self.type_and_value(
                 "MAP", lambda: m_self.parse(map_str))),
            (r'\'(?P<item>.*?)(?<!\\)\'',
             lambda m_self, string: m_self.type_and_value("BSTR", lambda: string)),
            (r'\"(?P<item>.*?)(?<!\\)\"',
             lambda m_self, string: m_self.type_and_value("TSTR", lambda: string)),
        ]
        # Join the range_types into one alternation. Each occurrence of "item" is renamed to a
        # distinct name (it0em, it1em, ...), presumably so that the group names don't clash when
        # the alternation is embedded in the single-slash union regex below.
        range_types_regex = '|'.join([regex for (regex, _) in range_types])
        for i in range(range_types_regex.count("item")):
            range_types_regex = range_types_regex.replace("item", "it%dem" % i, 1)

        # The following regexes match different parts of the element. The order of the list is
        # important because it implements the operator precedence defined in the CDDL spec.
        # The range_types are separate because they are reused in one of the other regexes.
        self_type.cddl_regexes[self_type] = range_types + [
            # "//" (group choice): the alternative is wrapped in () if it contains a comma.
            (r'\/\/\s*(?P<item>.+?)(?=\/\/|\Z)',
             lambda m_self, union_str: m_self.union_add_value(
                 m_self.parse("(%s)" % union_str if ',' in union_str else union_str)[0],
                 doubleslash=True)),
            # An identifier followed by ":" is a key or a label.
            (r'(?P<item>[^\W\d][\w-]*)\s*:',
             self_type.set_key_or_label),
            # "=>" or ":" means that what has been parsed so far is the key.
            (r'((\=\>)|:)',
             lambda m_self, _: m_self.convert_to_key()),
            # Quantifiers: "+", "*", "?" and the "n*m" forms.
            (r'([+*?])',
             self_type.set_quantifier),
            (r'(' + match_uint + r'\*\*?' + match_uint + r'?)',
             self_type.set_quantifier),
            # "/" (type choice): one alternative, which may contain bracketed parts or strings.
            (r'\/\s*(?P<item>((' + range_types_regex + r')|[^,\[\]{}()])+?)(?=\/|\Z|,)',
             lambda m_self, union_str: m_self.union_add_value(
                 m_self.parse(union_str)[0])),
            # Primitive type names; the element type is the upper-cased name.
            (r'(uint|nint|int|float|bstr|tstr|bool|nil|any)(?![\w-])',
             lambda m_self, type_str: m_self.type_and_value(type_str.upper(), lambda: None)),
            (r'undefined(?!\w)',
             lambda m_self, _: m_self.type_and_value("UNDEF", lambda: None)),
            # Sized float types; the trailing numbers are passed on as sizes (2, 4, 8).
            (r'float16(?![\w-])',
             lambda m_self, _: m_self.type_value_size("FLOAT", lambda: None, 2)),
            (r'float16-32(?![\w-])',
             lambda m_self, _: m_self.type_value_size_range("FLOAT", lambda: None, 2, 4)),
            (r'float32(?![\w-])',
             lambda m_self, _: m_self.type_value_size("FLOAT", lambda: None, 4)),
            (r'float32-64(?![\w-])',
             lambda m_self, _: m_self.type_value_size_range("FLOAT", lambda: None, 4, 8)),
            (r'float64(?![\w-])',
             lambda m_self, _: m_self.type_value_size("FLOAT", lambda: None, 8)),
            # Float literal.
            (r'\-?\d*\.\d+',
             lambda m_self, num: m_self.type_and_value("FLOAT", lambda: float(num))),
            # Integer ranges with ".."; the bounds are parsed with int(num, 0).
            (match_uint + r'\.\.' + match_uint,
             lambda m_self, _range: m_self.type_and_range(
                 "UINT", *map(lambda num: int(num, 0), _range.split("..")))),
            (match_nint + r'\.\.' + match_uint,
             lambda m_self, _range: m_self.type_and_range(
                 "INT", *map(lambda num: int(num, 0), _range.split("..")))),
            (match_nint + r'\.\.' + match_nint,
             lambda m_self, _range: m_self.type_and_range(
                 "NINT", *map(lambda num: int(num, 0), _range.split("..")))),
            # Integer ranges with "..."; these pass inc_end=False.
            (match_uint + r'\.\.\.' + match_uint,
             lambda m_self, _range: m_self.type_and_range(
                 "UINT", *map(lambda num: int(num, 0), _range.split("...")), inc_end=False)),
            (match_nint + r'\.\.\.' + match_uint,
             lambda m_self, _range: m_self.type_and_range(
                 "INT", *map(lambda num: int(num, 0), _range.split("...")), inc_end=False)),
            (match_nint + r'\.\.\.' + match_nint,
             lambda m_self, _range: m_self.type_and_range(
                 "NINT", *map(lambda num: int(num, 0), _range.split("...")), inc_end=False)),
            # Integer literals.
            (match_nint,
             lambda m_self, num: m_self.type_and_value("NINT", lambda: int(num, 0))),
            (match_uint,
             lambda m_self, num: m_self.type_and_value("UINT", lambda: int(num, 0))),
            # Boolean literals.
            (r'true(?!\w)',
             lambda m_self, _: m_self.type_and_value("BOOL", lambda: True)),
            (r'false(?!\w)',
             lambda m_self, _: m_self.type_and_value("BOOL", lambda: False)),
            # Tag, "#6.<number>".
            (r'#6\.(?P<item>\d+)',
             self_type.add_tag),
            # Any other identifier becomes an "OTHER" element holding the name.
            (r'(\$?\$?[\w-]+)',
             lambda m_self, other_str: m_self.type_and_value("OTHER", lambda: other_str)),
            # Control operators: .size, .gt, .lt, .ge, .le, .eq, .cbor, .cborseq and .bits.
            (r'\.size \(?(?P<item>' + match_int + r'\.\.' + match_int + r')\)?',
             lambda m_self, _range: m_self.set_size_range(
                 *map(lambda num: int(num, 0), _range.split("..")))),
            (r'\.size \(?(?P<item>' + match_int + r'\.\.\.' + match_int + r')\)?',
             lambda m_self, _range: m_self.set_size_range(
                 *map(lambda num: int(num, 0), _range.split("...")), inc_end=False)),
            (r'\.size \(?(?P<item>' + match_uint + r')\)?',
             lambda m_self, size: m_self.set_size(int(size, 0))),
            # .gt and .lt are exclusive, so the bound is shifted by one.
            (r'\.gt \(?(?P<item>' + match_int + r')\)?',
             lambda m_self, minvalue: m_self.set_min_value(int(minvalue, 0) + 1)),
            (r'\.lt \(?(?P<item>' + match_int + r')\)?',
             lambda m_self, maxvalue: m_self.set_max_value(int(maxvalue, 0) - 1)),
            (r'\.ge \(?(?P<item>' + match_int + r')\)?',
             lambda m_self, minvalue: m_self.set_min_value(int(minvalue, 0))),
            (r'\.le \(?(?P<item>' + match_int + r')\)?',
             lambda m_self, maxvalue: m_self.set_max_value(int(maxvalue, 0))),
            (r'\.eq \(?(?P<item>' + match_int + r')\)?',
             lambda m_self, value: m_self.set_value(lambda: int(value, 0))),
            (r'\.eq \"(?P<item>.*?)(?<!\\)\"',
             lambda m_self, value: m_self.set_value(lambda: value)),
            # .cbor/.cborseq: the operand is parsed as an element of its own. The last argument
            # to set_cbor() is False for .cbor and True for .cborseq.
            (r'\.cbor (\((?P<item>(?>[^\(\)]+|(?1))*)\))',
             lambda m_self, type_str: m_self.set_cbor(m_self.parse(type_str)[0], False)),
            (r'\.cbor (?P<item>[^\s,]+)',
             lambda m_self, type_str: m_self.set_cbor(m_self.parse(type_str)[0], False)),
            (r'\.cborseq (\((?P<item>(?>[^\(\)]+|(?1))*)\))',
             lambda m_self, type_str: m_self.set_cbor(m_self.parse(type_str)[0], True)),
            (r'\.cborseq (?P<item>[^\s,]+)',
             lambda m_self, type_str: m_self.set_cbor(m_self.parse(type_str)[0], True)),
            (r'\.bits (?P<item>[\w-]+)',
             lambda m_self, bits_str: m_self.set_bits(bits_str))
        ]
899
900    def get_value(self, instr):
901        """Parse from the beginning of instr (string) until a full element has been parsed.
902
903        self will become that element. This function is recursive, so if a nested element
904        ("MAP"/"LIST"/"UNION"/"GROUP") is encountered, this function will create new instances and
905        add them to self.value as a list. Likewise, if a key or cbor definition is encountered, a
906        new element will be created and assigned to self.key or self.cbor. When new elements are
907        created, get_value() is called on those elements, via parse().
908        """
909        types = type(self).cddl_regexes[type(self)]
910
911        # Keep parsing until a comma, or to the end of the string.
912        while instr != '' and instr[0] != ',':
913            match_obj = None
914            for (reg, handler) in types:
915                match_obj = getrp(reg).match(instr)
916                if match_obj:
917                    try:
918                        match_str = match_obj.group("item")
919                    except IndexError:
920                        match_str = match_obj.group(0)
921                    try:
922                        handler(self, match_str)
923                    except Exception as e:
924                        raise Exception("Failed while parsing this: '%s'" % match_str) from e
925                    self.match_str += match_str
926                    old_len = len(instr)
927                    instr = getrp(reg).sub('', instr, count=1).lstrip()
928                    if old_len == len(instr):
929                        raise Exception("empty match")
930                    break
931
932            if not match_obj:
933                raise TypeError("Could not parse this: '%s'" % instr)
934
935        instr = instr[1:]
936        if not self.type:
937            raise ValueError("No proper value while parsing: %s" % instr)
938
939        # Return the unparsed part of the string.
940        return instr.strip()
941
942    def elem_has_key(self):
943        """For checking whether this element has a key (i.e. that it is a valid "MAP" child)
944
945        This must have some recursion since CDDL allows the key to be hidden
946        behind layers of indirection.
947        """
948        return self.key is not None\
949            or (self.type == "OTHER" and self.my_types[self.value].elem_has_key())\
950            or (self.type in ["GROUP", "UNION"]
951                and (self.value and all(child.elem_has_key() for child in self.value)))
952
953    def post_validate(self):
954        """Function for performing validations that must be done after all parsing is complete.
955
956        This is recursive, so it will post_validate all its children + key + cbor.
957        """
958        # Validation of this element.
959        if self.type in ["LIST", "MAP"]:
960            none_keys = [child for child in self.value if not child.elem_has_key()]
961            child_keys = [child for child in self.value if child not in none_keys]
962            if self.type == "MAP" and none_keys:
963                raise TypeError(
964                    "Map member(s) must have key: " + str(none_keys) + " pointing to "
965                    + str(
966                        [self.my_types[elem.value] for elem in none_keys
967                            if elem.type == "OTHER"]))
968            if self.type == "LIST" and child_keys:
969                raise TypeError(
970                    str(self) + linesep
971                    + "List member(s) cannot have key: " + str(child_keys) + " pointing to "
972                    + str(
973                        [self.my_types[elem.value] for elem in child_keys
974                            if elem.type == "OTHER"]))
975        if self.type == "OTHER":
976            if self.value not in self.my_types.keys() or not isinstance(
977                    self.my_types[self.value], type(self)):
978                raise TypeError("%s has not been parsed." % self.value)
979        if self.type == "LIST":
980            for child in self.value[:-1]:
981                if child.type == "ANY":
982                    if child.min_qty != child.max_qty:
983                        raise TypeError(f"ambiguous quantity of 'any' is not supported in list, "
984                                        + "except as last element:\n{str(child)}")
985        if self.type == "UNION" and len(self.value) > 1:
986            if any(((not child.key and child.type == "ANY") or (
987                    child.key and child.key.type == "ANY")) for child in self.value):
988                raise TypeError(
989                    "'any' inside union is not supported since it would always be triggered.")
990
991        # Validation of child elements.
992        if self.type in ["MAP", "LIST", "UNION", "GROUP"]:
993            for child in self.value:
994                child.post_validate()
995        if self.key:
996            self.key.post_validate()
997        if self.cbor:
998            self.cbor.post_validate()
999
1000    def post_validate_control_group(self):
1001        if self.type != "GROUP":
1002            raise TypeError("control groups must be of GROUP type.")
1003        for c in self.value:
1004            if c.type != "UINT" or c.value is None or c.value < 0:
1005                raise TypeError("control group members must be literal positive integers.")
1006
1007    def parse(self, instr):
1008        """Parses entire instr and returns a list of instances."""
1009        instr = instr.strip()
1010        values = []
1011        while instr != '':
1012            value = type(self)(*self.init_args(), **self.init_kwargs(), base_stem=self.base_stem)
1013            instr = value.get_value(instr)
1014            values.append(value)
1015        return values
1016
1017    def __repr__(self):
1018        return self.mrepr(False)
1019
1020
# Names that are altered by CddlXcoder.var_name() (by capitalizing them) before they are used
# as C identifiers.
c_keywords = (
    "alignas alignof atomic_bool atomic_int auto bool break case char "
    "complex const constexpr continue default do double else enum "
    "extern false float for goto if imaginary inline int long "
    "noreturn nullptr register restrict return short signed sizeof static "
    "static_assert struct switch thread_local true typedef typeof "
    "typeof_unqual union unsigned void volatile while"
).split()
1028
1029
# Underscore-prefixed names that CddlXcoder.var_name() prepends another "_" to before they are
# used as C identifiers.
c_keywords_underscore = (
    "_Alignas _Alignof _Atomic _BitInt _Bool _Complex _Decimal128 _Decimal32 "
    "_Decimal64 _Generic _Imaginary _Noreturn _Pragma _Static_assert "
    "_Thread_local"
).split()
1034
1035
class CddlXcoder(CddlParser):
    """CddlParser with the naming and structural queries needed for encoding/decoding.

    Adds the names and access "paths" of the variables belonging to an element, and a set of
    *_condition() predicates that decide which extra variables ("present", "count", "choice",
    "key", "cbor"), typedefs and encoder/decoder functions an element needs.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # The prefix used for C code accessing this element, i.e. the struct
        # hierarchy leading up to this element.
        self.accessPrefix = None
        self.is_delegated = False
        # Used as a guard against endless recursion in self.dependsOn()
        self.dependsOnCall = False
        self.skipped = False

    def var_name(self, with_prefix=False, observe_skipped=True):
        """Name of variables and enum members for this element.

        If observe_skipped is set and this is a skipped, non-empty "LIST"/"MAP"/"GROUP", the name
        of its first child is used instead. Names that collide with an entry in c_keywords are
        capitalized, and names in c_keywords_underscore get an extra leading underscore.
        """
        if (observe_skipped and self.skip_condition()
                and self.type in ["LIST", "MAP", "GROUP"] and self.value):
            return self.value[0].var_name(with_prefix)
        name = self.id(with_prefix=with_prefix)
        if name in c_keywords:
            name = name.capitalize()
        elif name in c_keywords_underscore:
            name = "_" + name
        return name

    def skip_condition(self):
        """Whether this element should have its result variable omitted."""
        if self.skipped:
            return True
        if self.type in ["LIST", "MAP", "GROUP"]:
            return not self.repeated_multi_var_condition()
        if self.type == "OTHER":
            return ((not self.repeated_multi_var_condition())
                    and (not self.multi_var_condition())
                    and (self.single_func_impl_condition() or self in self.my_types.values()))
        return False

    def set_skipped(self, skipped):
        """Set self.skipped.

        The element is always marked as skipped when it has a range check, needs a function for
        its repeated part and has no key; otherwise the provided value is used.
        """
        if self.range_check_condition() \
           and self.repeated_single_func_impl_condition() \
           and not self.key:
            self.skipped = True
        else:
            self.skipped = skipped

    def delegate_type_condition(self):
        """Whether to use the C type of the first child as this type's C type"""
        ret = self.type in ["LIST", "MAP", "GROUP"]
        return ret

    def is_delegated_type(self):
        """Whether this element was marked as delegated by set_access_prefix()."""
        return self.is_delegated

    def set_access_prefix(self, prefix, is_delegated=False):
        """Recursively set the access prefix for this element and all its children.

        Also updates the skipped state of this element and its children, and the is_delegated
        flag of this element. The key and (if it has a "cbor" variable) the cbor element get
        this element's var_access() as their prefix.
        """
        self.accessPrefix = prefix
        if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
            self.set_skipped(self.skip_condition())
            # All children must have their skipped state updated before any of them gets its
            # prefix, so these are two separate passes.
            for child in self.value:
                child.set_skipped(child.skip_condition())
            for child in self.value:
                child.set_access_prefix(
                    self.var_access(),
                    is_delegated=(self.delegate_type_condition()
                                  or (is_delegated and self.skip_condition())))
        elif self in self.my_types.values():
            self.set_skipped(not self.multi_member())
        if self.key is not None:
            self.key.set_access_prefix(self.var_access())
        if self.cbor_var_condition():
            self.cbor.set_access_prefix(self.var_access())
        self.is_delegated = is_delegated and not self.skip_condition()

    def multi_member(self):
        """Whether this type has multiple member variables."""
        return self.multi_var_condition() or self.repeated_multi_var_condition()

    def is_unambiguous_value(self):
        """Whether this element is a non-compound value that can be known a priori."""
        return (self.type in ["NIL", "UNDEF", "ANY"]
                or (self.type in ["INT", "NINT", "UINT", "FLOAT", "BSTR", "TSTR", "BOOL"]
                    and self.value is not None)
                or (self.type == "OTHER" and self.my_types[self.value].is_unambiguous()))

    def is_unambiguous_repeated(self):
        """Whether the repeated part of this element is known a priori."""
        # Note that "and" binds tighter than "or": the key check only applies to the first
        # alternative.
        return (self.is_unambiguous_value()
                and (self.key is None or self.key.is_unambiguous_repeated())
                or (self.type in ["LIST", "GROUP", "MAP"] and len(self.value) == 0)
                or (self.type in ["LIST", "GROUP", "MAP"]
                    and all((child.is_unambiguous() for child in self.value))))

    def is_unambiguous(self):
        """Whether or not we can know the exact encoding of this element a priori."""
        return (self.is_unambiguous_repeated() and (self.min_qty == self.max_qty))

    def access_append_delimiter(self, prefix, delimiter, *suffix):
        """Create an access prefix based on an existing prefix, delimiter and a
        suffix.
        """
        assert prefix is not None, "No access prefix for %s" % self.var_name()
        return delimiter.join((prefix,) + suffix)

    def access_append(self, *suffix):
        """Create an access prefix from this element's prefix, delimiter and a
        provided suffix.
        """
        suffix = list(suffix)
        return self.access_append_delimiter(self.accessPrefix, '.', *suffix)

    def var_access(self):
        """"Path" to this element's variable, or "NULL" if the element is unambiguous."""
        if self.is_unambiguous():
            return "NULL"
        return self.access_append()

    def val_access(self):
        """"Path" to access this element's actual value variable.

        "NULL" if the repeated part is unambiguous. Skipped and delegated elements use the path
        of the element itself (var_access()).
        """
        if self.is_unambiguous_repeated():
            ret = "NULL"
        elif self.skip_condition() or self.is_delegated_type():
            ret = self.var_access()
        else:
            ret = self.access_append(self.var_name())
        return ret

    def repeated_val_access(self):
        """"Path" to the value variable, or "NULL" if the repeated part is unambiguous.

        Unlike val_access(), this does not take skipping or delegation into account.
        """
        if self.is_unambiguous_repeated():
            return "NULL"
        return self.access_append(self.var_name())

    def present_var_condition(self):
        """Whether to include a "present" variable for this element."""
        return self.min_qty == 0 and isinstance(self.max_qty, int) and self.max_qty <= 1

    def count_var_condition(self):
        """Whether to include a "count" variable for this element."""
        return isinstance(self.max_qty, str) or self.max_qty > 1

    def is_cbor(self):
        """Whether this element qualifies for a "cbor" variable when used as a cbor payload.

        False for "NIL", "UNDEF" and "ANY". An "OTHER" element defers to the type it refers to.
        See cbor_var_condition().
        """
        return (self.type not in ["NIL", "UNDEF", "ANY"]) \
            and ((self.type != "OTHER") or (self.my_types[self.value].is_cbor()))

    def cbor_var_condition(self):
        """Whether to include a "cbor" variable for this element."""
        return (self.cbor is not None) and self.cbor.is_cbor()

    def choice_var_condition(self):
        """Whether to include a "choice" variable for this element."""
        return self.type == "UNION"

    def reduced_key_var_condition(self):
        """Whether this specific element has a key (no indirection is followed)."""
        if self.key is not None:
            return True
        return False

    def key_var_condition(self):
        """Whether to include a "key" variable for this element.

        True if the element itself has a key, if it refers to a type for which this condition
        holds, or if it is a "GROUP"/"UNION" whose first child has a key.
        """
        if self.reduced_key_var_condition():
            return True
        if self.type == "OTHER" and self.my_types[self.value].key_var_condition():
            return True
        if (self.type in ["GROUP", "UNION"]
                and len(self.value) >= 1
                and self.value[0].reduced_key_var_condition()):
            return True
        return False

    def self_repeated_multi_var_condition(self):
        """Whether this value adds any repeated elements by itself. I.e. excluding
        multiple elements from children.
        """
        return (self.key_var_condition()
                or self.cbor_var_condition()
                or self.choice_var_condition())

    def multi_val_condition(self):
        """Whether this element's actual value has multiple members."""
        return (
            self.type in ["LIST", "MAP", "GROUP", "UNION"]
            and (len(self.value) > 1 or (len(self.value) == 1 and self.value[0].multi_member())))

    def repeated_multi_var_condition(self):
        """Whether any extra variables are to be included for this element for each
        repetition.
        """
        return self.self_repeated_multi_var_condition() or self.multi_val_condition()

    def multi_var_condition(self):
        """Whether any extra variables are to be included for this element outside
        of repetitions.
        Also, whether this element must involve a call to multi_xcode(), i.e. unless
        it's repeated exactly once.
        """
        return self.present_var_condition() or self.count_var_condition()

    def range_check_condition(self):
        """Whether this element needs a range or size check.

        This is the case for elements without a fixed value that are either integers with a
        minimum or maximum value, "UINT"s with bits, or strings with a minimum or maximum size.
        An "OTHER" element defers to the type it refers to.
        """
        if self.type == "OTHER":
            return self.my_types[self.value].range_check_condition()
        if self.type not in ["INT", "NINT", "UINT", "BSTR", "TSTR"]:
            return False
        if self.value is not None:
            return False
        if self.type in ["INT", "NINT", "UINT"] \
                and (self.min_value is not None or self.max_value is not None):
            return True
        if self.type == "UINT" and self.bits:
            return True
        if self.type in ["BSTR", "TSTR"] \
                and (self.min_size is not None or self.max_size is not None):
            return True
        return False

    def type_def_condition(self):
        """Whether this element should have a typedef in the code."""
        if self in self.my_types.values() and self.multi_member() and not self.is_unambiguous():
            return True
        return False

    def repeated_type_def_condition(self):
        """Whether this type needs a typedef for its repeated part."""
        return (
            self.repeated_multi_var_condition()
            and self.multi_var_condition()
            and not self.is_unambiguous_repeated())

    def single_func_impl_condition(self):
        """Whether this element needs its own encoder/decoder function."""
        return (
            False
            or self.reduced_key_var_condition()
            or self.cbor_var_condition()
            or (self.tags and self in self.my_types.values())
            or self.type_def_condition()
            or (self.type in ["LIST", "MAP"])
            or (self.type == "GROUP" and len(self.value) != 0))

    def repeated_single_func_impl_condition(self):
        """Whether the repeated part of this element needs its own encoder/decoder function."""
        return self.repeated_type_def_condition() \
            or (self.type in ["LIST", "MAP", "GROUP"] and self.multi_member()) \
            or (
                self.multi_var_condition()
                and (self.self_repeated_multi_var_condition() or self.range_check_condition()))

    def int_val(self):
        """If this element is an integer, or starts with an integer, return the integer value.

        Returns None otherwise, including for an empty "GROUP".
        """
        if self.key:
            return self.key.int_val()
        elif self.type in ("UINT", "NINT") and self.is_unambiguous():
            return self.value
        elif self.type == "GROUP" and self.value and not self.count_var_condition():
            # An empty group has no first member to look at, so it falls through to None.
            return self.value[0].int_val()
        elif self.type == "OTHER" \
                and not self.count_var_condition() \
                and not self.single_func_impl_condition() \
                and not self.my_types[self.value].single_func_impl_condition():
            return self.my_types[self.value].int_val()
        return None

    def is_int_disambiguated(self):
        """Whether this element starts with a specific integer that can be used to immediately
        disambiguate it from other elements.
        """
        return self.int_val() is not None

    def all_children_disambiguated(self, min_val, max_val):
        """Whether all children of this element can be disambiguated via a starting integer.

        This is relevant because it allows the decoder to directly decode the integer into an enum
        value.
        The min_val and max_val are to check whether the integers are within a certain range.
        """
        values = set(child.int_val() for child in self.value)
        # The length comparison fails if two children share the same integer.
        retval = (len(values) == len(self.value)) and None not in values \
            and max(values) <= max_val and min(values) >= min_val
        return retval

    def all_children_int_disambiguated(self):
        """See all_children_disambiguated()"""
        return self.all_children_disambiguated(INT32_MIN, INT32_MAX)

    def all_children_uint_disambiguated(self):
        """See all_children_disambiguated()"""
        return self.all_children_disambiguated(0, INT32_MAX)

    def present_var_name(self):
        """Name of the "present" variable for this element."""
        return "%s_present" % (self.var_name())

    def present_var_access(self):
        """Full "path" of the "present" variable for this element."""
        return self.access_append(self.present_var_name())

    def count_var_name(self):
        """Name of the "count" variable for this element."""
        return "%s_count" % (self.var_name())

    def count_var_access(self):
        """Full "path" of the "count" variable for this element."""
        return self.access_append(self.count_var_name())

    def choice_var_name(self):
        """Name of the "choice" variable for this element."""
        return self.var_name() + "_choice"

    def enum_var_name(self):
        """Name of the enum entry for this element."""
        return self.var_name(with_prefix=True) + "_c"

    def enum_var(self, int_val=False):
        """Enum entry for this element, with " = <value>" appended if int_val is set."""
        return f"{self.enum_var_name()} = {val_to_str(self.int_val())}" \
               if int_val else self.enum_var_name()

    def choice_var_access(self):
        """Full "path" of the "choice" variable for this element."""
        return self.access_append(self.choice_var_name())
1359
1360
class CddlValidationError(Exception):
    """Raised when data does not decode correctly against the CDDL."""
1363
1364
class KeyTuple(tuple):
    """A tuple subclass for key,value pairs.

    It exists so that isinstance() can tell these tuples apart from other tuples."""
    def __new__(cls, *in_tuple):
        return super().__new__(cls, *in_tuple)
1371
1372
1373class DataTranslator(CddlXcoder):
1374    """Convert data between CBOR, JSON and YAML, and validate against the provided CDDL.
1375
1376    Decode and validate CBOR into Python structures to be able to make Python scripts that
1377    manipulate CBOR code.
1378    """
1379
1380    @staticmethod
1381    def format_obj(obj):
1382        """Format a Python object for printing by adding newlines and indentation."""
1383        formatted = pformat(obj)
1384        out_str = ""
1385        indent = 0
1386        new_line = True
1387        for c in formatted:
1388            if new_line:
1389                if c == " ":
1390                    continue
1391                new_line = False
1392            out_str += c
1393            if c in "[(":
1394                indent += 1
1395            if c in ")]" and indent > 0:
1396                indent -= 1
1397            if c in "[(,":
1398                out_str += linesep
1399                out_str += "  " * indent
1400                new_line = True
1401        return out_str
1402
1403    def id(self):
1404        """Override the id() function.
1405
1406        If the name starts with an underscore, prepend an 'f',
1407        since namedtuple() doesn't support identifiers that start with an underscore.
1408        """
1409        return getrp(r"\A_").sub("f_", self.generate_base_name())
1410
1411    def var_name(self):
1412        """Override the var_name()"""
1413        return self.id()
1414
1415    def _decode_assert(self, test, msg=""):
1416        """Check a condition and raise a CddlValidationError if not."""
1417        if not test:
1418            raise CddlValidationError(
1419                f"Data did not decode correctly {'(' + msg + ')' if msg else ''}")
1420
1421    def _check_tag(self, obj):
1422        """Check that no unexpected tags are attached to this data.
1423
1424        Return whether a tag was present.
1425        """
1426        tags = copy(self.tags)  # All expected tags
1427        # Process all tags present in obj
1428        while isinstance(obj, CBORTag):
1429            if obj.tag in tags or self.type == "ANY":
1430                if obj.tag in tags:
1431                    tags.remove(obj.tag)
1432                obj = obj.value
1433                continue
1434            elif self.type in ["OTHER", "GROUP", "UNION"]:
1435                break
1436            self._decode_assert(False, f"Tag ({obj.tag}) not expected for {self}")
1437        # Check that all expected tags were found in obj.
1438        self._decode_assert(not tags, f"Expected tags ({tags}), but none present.")
1439        return obj
1440
1441    def _expected_type(self):
1442        """Return our expected python type as returned by cbor2."""
1443        return {
1444            "UINT": lambda: (int,),
1445            "INT": lambda: (int,),
1446            "NINT": lambda: (int,),
1447            "FLOAT": lambda: (float,),
1448            "TSTR": lambda: (str,),
1449            "BSTR": lambda: (bytes,),
1450            "NIL": lambda: (type(None),),
1451            "UNDEF": lambda: (type(undefined),),
1452            "ANY": lambda: (int, float, str, bytes, type(None), type(undefined), bool, list, dict),
1453            "BOOL": lambda: (bool,),
1454            "LIST": lambda: (tuple, list),
1455            "MAP": lambda: (dict,),
1456        }[self.type]()
1457
1458    def _check_type(self, obj):
1459        """Check that the decoded object has the correct type."""
1460        if self.type not in ["OTHER", "GROUP", "UNION"]:
1461            exp_type = self._expected_type()
1462            self._decode_assert(
1463                type(obj) in exp_type,
1464                f"{str(self)}: Wrong type ({type(obj)}) of {str(obj)}, expected {str(exp_type)}")
1465
1466    def _check_value(self, obj):
1467        """Check that the decode value conforms to the restrictions in the CDDL."""
1468        if self.type in ["UINT", "INT", "NINT", "FLOAT", "TSTR", "BSTR", "BOOL"] \
1469                and self.value is not None:
1470            value = self.value
1471            if self.type == "BSTR":
1472                value = self.value.encode("utf-8")
1473            self._decode_assert(
1474                self.value == obj,
1475                f"{obj} should have value {self.value} according to {self.var_name()}")
1476        if self.type in ["UINT", "INT", "NINT", "FLOAT"]:
1477            if self.min_value is not None:
1478                self._decode_assert(obj >= self.min_value, "Minimum value: " + str(self.min_value))
1479            if self.max_value is not None:
1480                self._decode_assert(obj <= self.max_value, "Maximum value: " + str(self.max_value))
1481        if self.type == "UINT":
1482            if self.bits:
1483                mask = sum(((1 << b.value) for b in self.my_control_groups[self.bits].value))
1484                self._decode_assert(not (obj & ~mask), "Allowed bitmask: " + bin(mask))
1485        if self.type in ["TSTR", "BSTR"]:
1486            if self.min_size is not None:
1487                self._decode_assert(
1488                    len(obj) >= self.min_size, "Minimum length: " + str(self.min_size))
1489            if self.max_size is not None:
1490                self._decode_assert(
1491                    len(obj) <= self.max_size, "Maximum length: " + str(self.max_size))
1492
1493    def _check_key(self, obj):
1494        """Check that the object is not a KeyTuple, which would mean it's not properly processed."""
1495        self._decode_assert(
1496            not isinstance(obj, KeyTuple), "Unexpected key found: (key,value)=" + str(obj))
1497
1498    def _flatten_obj(self, obj):
1499        """Recursively remove intermediate objects that have single members. Keep lists as is."""
1500        if isinstance(obj, tuple) and len(obj) == 1:
1501            return self._flatten_obj(obj[0])
1502        return obj
1503
1504    def _flatten_list(self, name, obj):
1505        """Return the contents of a list if it has a single member and the same name as us."""
1506        if (isinstance(obj, list)
1507                and len(obj) == 1
1508                and (isinstance(obj[0], list) or isinstance(obj[0], tuple))
1509                and len(obj[0]) == 1
1510                and hasattr(obj[0], name)):
1511            return [obj[0][0]]
1512        return obj
1513
1514    def _construct_obj(self, my_list):
1515        """Construct a namedtuple object from my_list. my_list contains tuples of name/value.
1516
1517        Also, attempt to flatten redundant levels of abstraction.
1518        """
1519        if my_list == []:
1520            return None
1521        names, values = tuple(zip(*my_list))
1522        if len(values) == 1:
1523            values = (self._flatten_obj(values[0]), )
1524        values = tuple(self._flatten_list(names[i], values[i]) for i in range(len(values)))
1525        assert (not any((isinstance(elem, KeyTuple) for elem in values))), \
1526            f"KeyTuple not processed: {values}"
1527        return namedtuple("_", names)(*values)
1528
1529    def _add_if(self, my_list, obj, expect_key=False, name=None):
1530        """Add construct obj and add it to my_list if relevant.
1531
1532        Also, process any KeyTuples present.
1533        """
1534        if expect_key and self.type == "OTHER" and self.key is None:
1535            self.my_types[self.value]._add_if(my_list, obj)
1536            return
1537        if self.is_unambiguous():
1538            return
1539        if isinstance(obj, list):
1540            for i in range(len(obj)):
1541                if isinstance(obj[i], KeyTuple):
1542                    retvals = list()
1543                    self._add_if(retvals, obj[i])
1544                    obj[i] = self._construct_obj(retvals)
1545                if self.type == "BSTR" and self.cbor_var_condition() and isinstance(obj[i], bytes):
1546                    assert all((isinstance(o, bytes) for o in obj)), \
1547                           """Unsupported configuration for cbor bstr. If a list contains a
1548CBOR-formatted bstr, all elements must be bstrs. If not, it is a programmer error."""
1549        if isinstance(obj, KeyTuple):
1550            key, obj = obj
1551            if key is not None:
1552                self.key._add_if(my_list, key, name=self.var_name() + "_key")
1553        if self.type == "BSTR" and self.cbor_var_condition():
1554            # If a bstr is CBOR-formatted, add both the string and the decoding of the string here
1555            if isinstance(obj, list) and all((isinstance(o, bytes) for o in obj)):
1556                # One or more bstr in a list (i.e. it is optional or repeated)
1557                my_list.append((name or self.var_name(), [self.cbor.decode_str(o) for o in obj]))
1558                my_list.append(((name or self.var_name()) + "_bstr", obj))
1559                return
1560            if isinstance(obj, bytes):
1561                my_list.append((name or self.var_name(), self.cbor.decode_str(obj)))
1562                my_list.append(((name or self.var_name()) + "_bstr", obj))
1563                return
1564        my_list.append((name or self.var_name(), obj))
1565
1566    def _iter_is_empty(self, it):
1567        """Throw CddlValidationError if iterator is not empty.
1568
1569        This consumes one element if present.
1570        """
1571        try:
1572            val = next(it)
1573        except StopIteration:
1574            return True
1575        raise CddlValidationError(
1576            f"Iterator not consumed while parsing \n{self}\nRemaining elements:\n elem: "
1577            + "\n elem: ".join(str(elem) for elem in ([val] + list(it))))
1578
1579    def _iter_next(self, it):
1580        """Get next element from iterator, throw CddlValidationError instead of StopIteration."""
1581        try:
1582            next_obj = next(it)
1583            return next_obj
1584        except StopIteration:
1585            raise CddlValidationError("Iterator empty")
1586
1587    def _decode_single_obj(self, obj):
1588        """Decode single CDDL value, excluding repetitions"""
1589        self._check_key(obj)
1590        obj = self._check_tag(obj)
1591        self._check_type(obj)
1592        self._check_value(obj)
1593        if self.type in ["UINT", "INT", "NINT", "FLOAT", "TSTR",
1594                         "BSTR", "BOOL", "NIL", "UNDEF", "ANY"]:
1595            return obj
1596        elif self.type == "OTHER":
1597            return self.my_types[self.value]._decode_single_obj(obj)
1598        elif self.type == "LIST":
1599            retval = list()
1600            child_val = iter(obj)
1601            for child in self.value:
1602                ret = child._decode_full(child_val)
1603                child_val, child_obj = ret
1604                child._add_if(retval, child_obj)
1605            self._iter_is_empty(child_val)
1606            return self._construct_obj(retval)
1607        elif self.type == "MAP":
1608            retval = list()
1609            child_val = iter(KeyTuple(item) for item in obj.items())
1610            for child in self.value:
1611                child_val, child_key_val = child._decode_full(child_val)
1612                child._add_if(retval, child_key_val, expect_key=True)
1613            self._iter_is_empty(child_val)
1614            return self._construct_obj(retval)
1615        elif self.type == "UNION":
1616            retval = list()
1617            for child in self.value:
1618                try:
1619                    child_obj = child._decode_single_obj(obj)
1620                    child._add_if(retval, child_obj)
1621                    retval.append(("union_choice", child.var_name()))
1622                    return self._construct_obj(retval)
1623                except CddlValidationError as c:
1624                    self.errors.append(str(c))
1625            self._decode_assert(False, "No matches for union: " + str(self))
1626        assert False, "Unexpected type: " + self.type
1627
1628    def _handle_key(self, next_obj):
1629        """Decode key and value in the form of a KeyTuple"""
1630        self._decode_assert(
1631            isinstance(next_obj, KeyTuple), f"Expected key: {self.key} value=" + pformat(next_obj))
1632        key, obj = next_obj
1633        key_res = self.key._decode_single_obj(key)
1634        obj_res = self._decode_single_obj(obj)
1635        res = KeyTuple((key_res if not self.key.is_unambiguous() else None, obj_res))
1636        return res
1637
1638    def _decode_obj(self, it):
1639        """Decode single CDDL value, excluding repetitions.
1640
1641        May consume 0 to n CBOR objects via the iterator.
1642        """
1643        my_list = list()
1644        if self.key is not None:
1645            it, it_copy = tee(it)
1646            key_res = self._handle_key(self._iter_next(it_copy))
1647            return it_copy, key_res
1648        if self.tags:
1649            it, it_copy = tee(it)
1650            maybe_tag = next(it_copy)
1651            if isinstance(maybe_tag, CBORTag):
1652                tag_res = self._decode_single_obj(maybe_tag)
1653                return it_copy, tag_res
1654        if self.type == "OTHER" and self.key is None:
1655            return self.my_types[self.value]._decode_full(it)
1656        elif self.type == "GROUP":
1657            my_list = list()
1658            child_it = it
1659            for child in self.value:
1660                child_it, child_obj = child._decode_full(child_it)
1661                if child.key is not None:
1662                    child._add_if(my_list, child_obj, expect_key=True)
1663                else:
1664                    child._add_if(my_list, child_obj)
1665            ret = (child_it, self._construct_obj(my_list))
1666        elif self.type == "UNION":
1667            my_list = list()
1668            child_it = it
1669            found = False
1670            for child in self.value:
1671                try:
1672                    child_it, it_copy = tee(child_it)
1673                    child_it, child_obj = child._decode_full(child_it)
1674                    child._add_if(my_list, child_obj)
1675                    my_list.append(("union_choice", child.var_name()))
1676                    ret = (child_it, self._construct_obj(my_list))
1677                    found = True
1678                    break
1679                except CddlValidationError as c:
1680                    self.errors.append(str(c))
1681                    child_it = it_copy
1682            self._decode_assert(found, "No matches for union: " + str(self))
1683        else:
1684            ret = (it, self._decode_single_obj(self._iter_next(it)))
1685        return ret
1686
1687    def _decode_full(self, it):
1688        """Decode single CDDL value, with repetitions.
1689
1690        May consume 0 to n CBOR objects via the iterator.
1691        """
1692        if self.multi_var_condition():
1693            retvals = []
1694            for i in range(self.min_qty):
1695                it, retval = self._decode_obj(it)
1696                retvals.append(retval if not self.is_unambiguous_repeated() else None)
1697            try:
1698                for i in range(self.max_qty - self.min_qty):
1699                    it, it_copy = tee(it)
1700                    it, retval = self._decode_obj(it)
1701                    retvals.append(retval if not self.is_unambiguous_repeated() else None)
1702            except CddlValidationError as c:
1703                self.errors.append(str(c))
1704                it = it_copy
1705            return it, retvals
1706        else:
1707            ret = self._decode_obj(it)
1708            return ret
1709
1710    def decode_obj(self, obj):
1711        """CBOR object => python object"""
1712        it = iter([obj])
1713        try:
1714            _, decoded = self._decode_full(it)
1715            self._iter_is_empty(it)
1716        except CddlValidationError as e:
1717            if self.errors:
1718                print("Errors:")
1719                pprint(self.errors)
1720            raise e
1721        return decoded
1722
1723    def decode_str_yaml(self, yaml_str, yaml_compat=False):
1724        """YAML => python object"""
1725        yaml_obj = yaml_load(yaml_str)
1726        obj = self._from_yaml_obj(yaml_obj) if yaml_compat else yaml_obj
1727        self.validate_obj(obj)
1728        return self.decode_obj(obj)
1729
1730    def decode_str(self, cbor_str):
1731        """CBOR bytestring => python object"""
1732        cbor_obj = loads(cbor_str)
1733        return self.decode_obj(cbor_obj)
1734
1735    def validate_obj(self, obj):
1736        """Validate CBOR object against CDDL. Exception if not valid."""
1737        self.decode_obj(obj)
1738        return True
1739
1740    def validate_str(self, cbor_str):
1741        """Validate CBOR bytestring against CDDL. Exception if not valid."""
1742        cbor_obj = loads(cbor_str)
1743        return self.validate_obj(cbor_obj)
1744
1745    def _from_yaml_obj(self, obj):
1746        """Convert object from YAML/JSON (with special dicts for bstr, tag etc) to CBOR object
1747        that cbor2 understands.
1748        """
1749        if isinstance(obj, list):
1750            if len(obj) == 1 and obj[0] == "zcbor_undefined":
1751                return undefined
1752            return [self._from_yaml_obj(elem) for elem in obj]
1753        elif isinstance(obj, dict):
1754            if ["zcbor_bstr"] == list(obj.keys()):
1755                if isinstance(obj["zcbor_bstr"], str):
1756                    bstr = bytes.fromhex(obj["zcbor_bstr"])
1757                else:
1758                    bstr = dumps(self._from_yaml_obj(obj["zcbor_bstr"]))
1759                return bstr
1760            elif ["zcbor_tag", "zcbor_tag_val"] == list(obj.keys()):
1761                return CBORTag(obj["zcbor_tag"], self._from_yaml_obj(obj["zcbor_tag_val"]))
1762            retval = dict()
1763            for key, val in obj.items():
1764                match = getrp(r"zcbor_keyval\d+").fullmatch(key)
1765                if match is not None:
1766                    new_key = self._from_yaml_obj(val["key"])
1767                    new_val = self._from_yaml_obj(val["val"])
1768                    if isinstance(new_key, list):
1769                        new_key = tuple(new_key)
1770                    retval[new_key] = new_val
1771                else:
1772                    retval[key] = self._from_yaml_obj(val)
1773            return retval
1774        return obj
1775
1776    def _to_yaml_obj(self, obj):
1777        """inverse of _from_yaml_obj"""
1778        if isinstance(obj, list) or isinstance(obj, tuple):
1779            return [self._to_yaml_obj(elem) for elem in obj]
1780        elif isinstance(obj, dict):
1781            retval = dict()
1782            i = 0
1783            for key, val in obj.items():
1784                if not isinstance(key, str):
1785                    retval[f"zcbor_keyval{i}"] = {
1786                        "key": self._to_yaml_obj(key), "val": self._to_yaml_obj(val)}
1787                    i += 1
1788                else:
1789                    retval[key] = self._to_yaml_obj(val)
1790            return retval
1791        elif isinstance(obj, bytes):
1792            f = BytesIO(obj)
1793            try:
1794                bstr_obj = self._to_yaml_obj(load(f))
1795            except (CBORDecodeValueError, CBORDecodeEOF):
1796                # failed decoding
1797                bstr_obj = obj.hex()
1798            else:
1799                if f.read(1) != b'':
1800                    # not fully decoded
1801                    bstr_obj = obj.hex()
1802            return {"zcbor_bstr": bstr_obj}
1803        elif isinstance(obj, CBORTag):
1804            return {"zcbor_tag": obj.tag, "zcbor_tag_val": self._to_yaml_obj(obj.value)}
1805        elif obj is undefined:
1806            return ["zcbor_undefined"]
1807        assert not isinstance(obj, bytes)
1808        return obj
1809
1810    def from_yaml(self, yaml_str, yaml_compat=False):
1811        """YAML str => CBOR bytestr"""
1812        yaml_obj = yaml_load(yaml_str)
1813        obj = self._from_yaml_obj(yaml_obj) if yaml_compat else yaml_obj
1814        self.validate_obj(obj)
1815        return dumps(obj)
1816
1817    def obj_to_yaml(self, obj, yaml_compat=False):
1818        """CBOR object => YAML str"""
1819        self.validate_obj(obj)
1820        yaml_obj = self._to_yaml_obj(obj) if yaml_compat else obj
1821        return yaml_dump(yaml_obj)
1822
1823    def str_to_yaml(self, cbor_str, yaml_compat=False):
1824        """CBOR bytestring => YAML str"""
1825        return self.obj_to_yaml(loads(cbor_str), yaml_compat=yaml_compat)
1826
1827    def from_json(self, json_str, yaml_compat=False):
1828        """JSON str => CBOR bytestr"""
1829        json_obj = json_load(json_str)
1830        obj = self._from_yaml_obj(json_obj) if yaml_compat else json_obj
1831        self.validate_obj(obj)
1832        return dumps(obj)
1833
1834    def obj_to_json(self, obj, yaml_compat=False):
1835        """CBOR object => JSON str"""
1836        self.validate_obj(obj)
1837        json_obj = self._to_yaml_obj(obj) if yaml_compat else obj
1838        return json_dump(json_obj)
1839
1840    def str_to_json(self, cbor_str, yaml_compat=False):
1841        """CBOR bytestring => JSON str"""
1842        return self.obj_to_json(loads(cbor_str), yaml_compat=yaml_compat)
1843
1844    def str_to_c_code(self, cbor_str, var_name, columns=0):
1845        """CBOR bytestring => C code (uint8_t array initialization)"""
1846        arr = ", ".join(f"0x{c:02x}" for c in cbor_str)
1847        if columns:
1848            arr = '\n' + indent("\n".join(wrap(arr, 6 * columns)), '\t') + '\n'
1849        return f'uint8_t {var_name}[] = {{{arr}}};\n'
1850
1851
class XcoderTuple(NamedTuple):
    """Named tuple of a body, a function name and a type name.

    NOTE(review): presumably describes one generated encode/decode function. How the fields
    are used is not visible from this part of the file -- confirm against the users.
    """
    body: list
    func_name: str
    type_name: str
1856
1857
class CddlTypes(NamedTuple):
    """Named tuple of the two dicts my_types and my_control_groups.

    NOTE(review): presumably the type definitions and control groups parsed from the CDDL,
    indexed by name -- confirm against where the tuple is created.
    """
    my_types: dict
    my_control_groups: dict
1861
1862
1863class CodeGenerator(CddlXcoder):
1864    """Class for generating C code that encode/decodes CBOR and validates it according to the CDDL.
1865    """
1866    def __init__(self, mode, entry_type_names, default_bit_size, *args, **kwargs):
1867        super(CodeGenerator, self).__init__(*args, **kwargs)
1868        self.mode = mode
1869        self.entry_type_names = entry_type_names
1870        self.default_bit_size = default_bit_size
1871
1872    @classmethod
1873    def from_cddl(cddl_class, mode, *args, **kwargs):
1874        cddl_res = super(CodeGenerator, cddl_class).from_cddl(*args, **kwargs)
1875
1876        # set access prefix (struct access paths) for all the definitions.
1877        for my_type in cddl_res.my_types:
1878            cddl_res.my_types[my_type].set_access_prefix(f"(*{struct_ptr_name(mode)})")
1879
1880        return cddl_res
1881
1882    def is_entry_type(self):
1883        """Whether this element (an OTHER) refers to an entry type."""
1884        return (self.type == "OTHER") and (self.value in self.entry_type_names)
1885
1886    def is_cbor(self):
1887        """Whether to include a "cbor" variable for this element."""
1888        res = (self.type_name() is not None) and not self.is_entry_type() and (
1889            (self.type != "OTHER") or self.my_types[self.value].is_cbor())
1890        return res
1891
1892    def init_args(self):
1893        return (self.mode, self.entry_type_names, self.default_bit_size, self.default_max_qty)
1894
1895    def delegate_type_condition(self):
1896        """Whether to use the C type of the first child as this type's C type"""
1897        ret = self.skip_condition() and (self.multi_var_condition()
1898                                         or self.self_repeated_multi_var_condition()
1899                                         or self.range_check_condition()
1900                                         or (self in self.my_types.values()))
1901        return ret
1902
1903    def is_delegated_type(self):
1904        return self.is_delegated
1905
1906    def present_var(self):
1907        """Declaration of the "present" variable for this element."""
1908        return ["bool %s;" % self.present_var_name()]
1909
1910    def count_var(self):
1911        """Declaration of the "count" variable for this element."""
1912        return ["size_t %s;" % self.count_var_name()]
1913
1914    def anonymous_choice_var(self):
1915        """Declaration of the "choice" variable for this element."""
1916        int_vals = self.all_children_int_disambiguated()
1917        return self.enclose("enum", [val.enum_var(int_vals) + "," for val in self.value])
1918
1919    def choice_var(self):
1920        """Declaration of the "choice" variable for this element."""
1921        var = self.anonymous_choice_var()
1922        var[-1] += f" {self.choice_var_name()};"
1923        return var
1924
1925    def child_declarations(self):
1926        """Declaration of the variables of all children."""
1927        decl = [line for child in self.value for line in child.full_declaration()]
1928        return decl
1929
1930    def child_single_declarations(self):
1931        """Declaration of the variables of all children."""
1932        decl = list()
1933        for child in self.value:
1934            if not child.is_unambiguous_repeated():
1935                decl.extend(child.single_declaration())
1936        return decl
1937
1938    def simple_func_condition(self):
1939        if self.range_check_condition():
1940            return True
1941        if self.single_func_impl_condition():
1942            return True
1943        if self.type == "OTHER" and self.my_types[self.value].simple_func_condition():
1944            return True
1945        return False
1946
1947    def raw_type_name(self):
1948        """Base name if this element needs to declare a type."""
1949        return "struct %s" % self.id()
1950
1951    def enum_type_name(self):
1952        return "enum %s" % self.id()
1953
1954    def bit_size(self):
1955        """The bit width of the integers as represented in code."""
1956        bit_size = None
1957        if self.type in ["UINT", "INT", "NINT"]:
1958            assert self.default_bit_size in [32, 64], "The default_bit_size must be 32 or 64."
1959            if self.default_bit_size == 64:
1960                bit_size = 64
1961            else:
1962                bit_size = 32
1963
1964                for v in [self.value or 0, self.max_value or 0, self.min_value or 0]:
1965                    if (type(v) is str):
1966                        if "64" in v:
1967                            bit_size = 64
1968                    elif self.type == "UINT":
1969                        if (v > UINT32_MAX):
1970                            bit_size = 64
1971                    else:
1972                        if (v > INT32_MAX) or (v < INT32_MIN):
1973                            bit_size = 64
1974        return bit_size
1975
1976    def float_type(self):
1977        """If this is a floating point number, return the C type to use for it."""
1978        if self.type != "FLOAT":
1979            return None
1980
1981        max_size = self.max_size or 8
1982
1983        if max_size <= 4:
1984            return "float"
1985        elif max_size == 8:
1986            return "double"
1987        else:
1988            raise TypeError("Floats must have 4 or 8 bytes of precision.")
1989
1990    def val_type_name(self):
1991        """Name of the type of this element's actual value variable."""
1992        if self.multi_val_condition():
1993            return self.raw_type_name()
1994
1995        # Will fail runtime if we don't use lambda for type_name()
1996        # pylint: disable=unnecessary-lambda
1997        name = {
1998            "INT": lambda: f"int{self.bit_size()}_t",
1999            "UINT": lambda: f"uint{self.bit_size()}_t",
2000            "NINT": lambda: f"int{self.bit_size()}_t",
2001            "FLOAT": lambda: self.float_type(),
2002            "BSTR": lambda: "struct zcbor_string",
2003            "TSTR": lambda: "struct zcbor_string",
2004            "BOOL": lambda: "bool",
2005            "NIL": lambda: None,
2006            "UNDEF": lambda: None,
2007            "ANY": lambda: None,
2008            "LIST": lambda: self.value[0].type_name() if len(self.value) >= 1 else None,
2009            "MAP": lambda: self.value[0].type_name() if len(self.value) >= 1 else None,
2010            "GROUP": lambda: self.value[0].type_name() if len(self.value) >= 1 else None,
2011            "UNION": lambda: self.union_type(),
2012            "OTHER": lambda: self.my_types[self.value].type_name(),
2013        }[self.type]()
2014
2015        return name
2016
2017    def repeated_type_name(self):
2018        """Name of the type for the repeated part of this element.
2019
2020        I.e. the part that happens multiple times if the element has a quantifier.
2021        not including things like the "count" or "present" variable.
2022        """
2023        if self.self_repeated_multi_var_condition():
2024            name = self.raw_type_name()
2025            if self.val_type_name() == name:
2026                name = name + "_r"
2027        else:
2028            name = self.val_type_name()
2029        return name
2030
2031    def type_name(self):
2032        """Name of the type for this element."""
2033        if self.multi_var_condition():
2034            name = self.raw_type_name()
2035        else:
2036            name = self.repeated_type_name()
2037        return name
2038
2039    def add_var_name(self, var_type, full=False, anonymous=False):
2040        """Take a multi member type name and create a variable declaration.
2041
2042        Make it an array if the element is repeated.
2043        """
2044        if var_type:
2045            assert (var_type[-1][-1] == "}" or len(var_type) == 1), \
2046                f"Expected single var: {var_type!r}"
2047            if not anonymous or var_type[-1][-1] != "}":
2048                var_name = self.var_name()
2049                array_part = f"[{self.max_qty}]" if full and self.max_qty != 1 else ""
2050                var_type[-1] += f" {var_name}{array_part}"
2051            var_type = add_semicolon(var_type)
2052        return var_type
2053
2054    def var_type(self):
2055        """The type for this element as a member variable."""
2056        if not self.multi_val_condition() and self.val_type_name() is not None:
2057            return [self.val_type_name()]
2058        elif self.type == "UNION":
2059            return self.union_type()
2060        return []
2061
2062    def enclose(self, ingress, declaration):
2063        """Enclose a list of declarations in a block (struct, union or enum)."""
2064        if declaration:
2065            return [f"{ingress} {{"] + [indentation + line for line in declaration] + ["}"]
2066        else:
2067            return []
2068
2069    def union_type(self):
2070        """Type declaration for unions."""
2071        declaration = self.enclose("union", self.child_single_declarations())
2072        return declaration
2073
2074    def single_declaration(self):
2075        return self.add_var_name(self.single_var_type(), anonymous=True)
2076
2077    def repeated_declaration(self):
2078        """Declaration of the repeated part of this element."""
2079        if self.is_unambiguous_repeated():
2080            return []
2081
2082        var_type = self.var_type()
2083        multi_var = False
2084
2085        decl = []
2086
2087        if not self.skip_condition():
2088            decl += self.add_var_name(var_type, anonymous=(self.type == "UNION"))
2089
2090        if self.type in ["LIST", "MAP", "GROUP"]:
2091            decl += self.child_declarations()
2092            multi_var = len(decl) > 1
2093
2094        if self.reduced_key_var_condition():
2095            key_var = self.key.full_declaration()
2096            decl = key_var + decl
2097            multi_var = key_var != []
2098
2099        if self.choice_var_condition():
2100            choice_var = self.choice_var()
2101            decl += choice_var
2102            multi_var = choice_var != []
2103
2104        if self.cbor_var_condition():
2105            cbor_var = self.cbor.full_declaration()
2106            decl += cbor_var
2107            multi_var = cbor_var != []
2108
2109        return decl
2110
2111    def full_declaration(self):
2112        """Declaration of the full type for this element."""
2113        multi_var = False
2114
2115        if self.is_unambiguous():
2116            return []
2117
2118        if self.multi_var_condition():
2119            if self.is_unambiguous_repeated():
2120                decl = []
2121            else:
2122                decl = self.add_var_name(
2123                    [self.repeated_type_name()]
2124                    if self.repeated_type_name() is not None else [], full=True)
2125        else:
2126            decl = self.repeated_declaration()
2127
2128        if self.count_var_condition():
2129            count_var = self.count_var()
2130            decl += count_var
2131            multi_var = count_var != []
2132
2133        if self.present_var_condition():
2134            present_var = self.present_var()
2135            decl += present_var
2136            multi_var = present_var != []
2137
2138        assert multi_var == self.multi_var_condition()
2139
2140        return decl
2141
2142    def single_var_type(self, full=True):
2143        """Return the type definition of this element.
2144
2145        If there are multiple variables, wrap them in a
2146        struct so the function always returns a single type with no name. If full is False, only
2147        repeated part is used.
2148        """
2149        if full and self.multi_member():
2150            return self.enclose("struct", self.full_declaration())
2151        elif not full and self.repeated_multi_var_condition():
2152            return self.enclose("struct", self.repeated_declaration())
2153        else:
2154            return self.var_type()
2155
2156    def type_def(self):
2157        """Return the type definition of this element, and all its children + key + cbor."""
2158        ret_val = []
2159        if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
2160            ret_val.extend(
2161                [elem for typedef in [
2162                    child.type_def() for child in self.value] for elem in typedef])
2163        if self.bits:
2164            ret_val.extend(self.my_control_groups[self.bits].type_def_bits())
2165        if self.cbor_var_condition():
2166            ret_val.extend(self.cbor.type_def())
2167        if self.reduced_key_var_condition():
2168            ret_val.extend(self.key.type_def())
2169        if self.type == "OTHER":
2170            ret_val.extend(self.my_types[self.value].type_def())
2171        if self.repeated_type_def_condition():
2172            type_def_list = self.single_var_type(full=False)
2173            if type_def_list:
2174                ret_val.extend([(self.single_var_type(full=False), self.repeated_type_name())])
2175        if self.type_def_condition():
2176            type_def_list = self.single_var_type()
2177            if type_def_list:
2178                ret_val.extend([(self.single_var_type(), self.type_name())])
2179        return ret_val
2180
2181    def type_def_bits(self):
2182        tdef = self.anonymous_choice_var()
2183        return [(tdef, self.enum_type_name())]
2184
2185    def float_prefix(self):
2186        if self.type != "FLOAT":
2187            return ""
2188
2189        min_size = self.min_size or 2
2190        max_size = self.max_size or 8
2191
2192        if max_size == 2:
2193            return "float16"
2194        elif min_size == 2 and max_size == 4:
2195            return "float16_32" if self.mode == "decode" else "float32"
2196        if min_size == 4 and max_size == 4:
2197            return "float32"
2198        elif min_size == 4 and max_size == 8:
2199            return "float32_64" if self.mode == "decode" else "float64"
2200        elif min_size == 8 and max_size == 8:
2201            return "float64"
2202        elif min_size <= 4 and max_size == 8:
2203            return "float" if self.mode == "decode" else "float64"
2204        else:
2205            raise TypeError("Floats must have 2, 4 or 8 bytes of precision.")
2206
2207    def single_func_prim_prefix(self):
2208        if self.type == "OTHER":
2209            return self.my_types[self.value].single_func_prim_prefix()
2210        return ({
2211            "INT": f"zcbor_int{self.bit_size()}",
2212            "UINT": f"zcbor_uint{self.bit_size()}",
2213            "NINT": f"zcbor_int{self.bit_size()}",
2214            "FLOAT": f"zcbor_{self.float_prefix()}",
2215            "BSTR": f"zcbor_bstr",
2216            "TSTR": f"zcbor_tstr",
2217            "BOOL": f"zcbor_bool",
2218            "NIL": f"zcbor_nil",
2219            "UNDEF": f"zcbor_undefined",
2220            "ANY": f"zcbor_any",
2221        }[self.type])
2222
2223    def xcode_func_name(self):
2224        """Name of the encoder/decoder function for this element."""
2225        return f"{self.mode}_{self.var_name(with_prefix=True, observe_skipped=False)}"
2226
2227    def repeated_xcode_func_name(self):
2228        """Name of the encoder/decoder function for the repeated part of this element."""
2229        return f"{self.mode}_repeated_{self.var_name(with_prefix=True, observe_skipped=False)}"
2230
2231    def single_func_prim_name(self, union_int=None, ptr_result=False):
2232        """Function name for xcoding this type, when it is a primitive type"""
2233        ptr_variant = ptr_result and self.type in ["UINT", "INT", "NINT", "FLOAT", "BOOL"]
2234        func_prefix = self.single_func_prim_prefix()
2235        if self.mode == "decode":
2236            if self.type == "ANY":
2237                func = "zcbor_any_skip"
2238            elif not self.is_unambiguous_value():
2239                func = f"{func_prefix}_decode"
2240            elif not union_int:
2241                func = f"{func_prefix}_{'pexpect' if ptr_variant else 'expect'}"
2242            elif union_int == "EXPECT":
2243                assert not ptr_variant, \
2244                       "Programmer error: invalid use of expect_union."
2245                func = f"{func_prefix}_expect_union"
2246            elif union_int == "DROP":
2247                return None
2248        else:
2249            if self.type == "ANY":
2250                func = "zcbor_nil_put"
2251            elif (not self.is_unambiguous_value()) or self.type in ["TSTR", "BSTR"] or ptr_variant:
2252                func = f"{func_prefix}_encode"
2253            else:
2254                func = f"{func_prefix}_put"
2255        return func
2256
2257    def single_func_prim(self, access, union_int=None, ptr_result=False):
2258        """Return the function name and arguments to call to encode/decode this element.
2259
2260        Only used when this element DOESN'T define its own encoder/decoder function (when it's a
2261        primitive type, for which functions already exist, or when the function is defined elsewhere
2262        ("OTHER"))
2263        """
2264        assert self.type not in ["LIST", "MAP"], "Must have wrapper function for list or map."
2265
2266        if self.type == "GROUP":
2267            assert len(self.value) == 0, "Group should have no children to get here."
2268            return (None, None)
2269
2270        if self.type == "OTHER":
2271            return self.my_types[self.value].single_func(access, union_int)
2272
2273        func_name = self.single_func_prim_name(union_int, ptr_result=ptr_result)
2274        if func_name is None:
2275            return (None, None)
2276
2277        if self.type in ["NIL", "UNDEF", "ANY"]:
2278            arg = "NULL"
2279        elif not self.is_unambiguous_value():
2280            arg = deref_if_not_null(access)
2281        elif self.type in ["BSTR", "TSTR"]:
2282            arg = tmp_str_or_null(self.value)
2283        elif self.type in ["UINT", "INT", "NINT", "FLOAT", "BOOL"]:
2284            value = val_to_str(self.value)
2285            arg = (f"&({self.val_type_name()}){{{value}}}" if ptr_result else value)
2286        else:
2287            assert False, "Should not come here."
2288
2289        return (func_name, arg)
2290
2291    def single_func(self, access=None, union_int=None):
2292        """Return the function name and arguments to call to encode/decode this element."""
2293        if self.single_func_impl_condition():
2294            return (self.xcode_func_name(), deref_if_not_null(access or self.var_access()))
2295        else:
2296            return self.single_func_prim(access or self.val_access(), union_int)
2297
2298    def repeated_single_func(self, ptr_result=False):
2299        """Return the function name and arguments to call to encode/decode the repeated
2300        part of this element.
2301        """
2302        if self.repeated_single_func_impl_condition():
2303            return (self.repeated_xcode_func_name(), deref_if_not_null(self.repeated_val_access()))
2304        else:
2305            return self.single_func_prim(self.repeated_val_access(), ptr_result=ptr_result)
2306
2307    def has_backup(self):
2308        return (self.cbor_var_condition() or self.type in ["LIST", "MAP", "UNION"])
2309
2310    def num_backups(self):
2311        total = 0
2312        if self.key:
2313            total += self.key.num_backups()
2314        if self.cbor_var_condition():
2315            total += self.cbor.num_backups()
2316        if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
2317            total += max([child.num_backups() for child in self.value] + [0])
2318        if self.type == "OTHER":
2319            total += self.my_types[self.value].num_backups()
2320        if self.has_backup():
2321            total += 1
2322        return total
2323
2324    def depends_on(self):
2325        """Return a number indicating how many other elements this element depends on.
2326
2327        Used for putting functions and typedefs in the right order.
2328        """
2329        ret_vals = [1]
2330
2331        if not self.dependsOnCall:
2332            self.dependsOnCall = True
2333            if self.cbor_var_condition():
2334                ret_vals.append(self.cbor.depends_on())
2335            if self.key:
2336                ret_vals.append(self.key.depends_on())
2337            if self.type == "OTHER":
2338                ret_vals.append(1 + self.my_types[self.value].depends_on())
2339            if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
2340                ret_vals.extend(child.depends_on() for child in self.value)
2341            self.dependsOnCall = False
2342
2343        return max(ret_vals)
2344
2345    def xcode_single_func_prim(self, union_int=None):
2346        """Make a string from the list returned by single_func_prim()"""
2347        return xcode_statement(*self.single_func_prim(self.val_access(), union_int))
2348
2349    def list_counts(self):
2350        """Recursively sum the total minimum and maximum element count for this element."""
2351        retval = ({
2352            "INT": lambda: (self.min_qty, self.max_qty),
2353            "UINT": lambda: (self.min_qty, self.max_qty),
2354            "NINT": lambda: (self.min_qty, self.max_qty),
2355            "FLOAT": lambda: (self.min_qty, self.max_qty),
2356            "BSTR": lambda: (self.min_qty, self.max_qty),
2357            "TSTR": lambda: (self.min_qty, self.max_qty),
2358            "BOOL": lambda: (self.min_qty, self.max_qty),
2359            "NIL": lambda: (self.min_qty, self.max_qty),
2360            "UNDEF": lambda: (self.min_qty, self.max_qty),
2361            "ANY": lambda: (self.min_qty, self.max_qty),
2362            # Lists are their own element
2363            "LIST": lambda: (self.min_qty, self.max_qty),
2364            # Maps are their own element
2365            "MAP": lambda: (self.min_qty, self.max_qty),
2366            "GROUP": lambda: (self.min_qty * sum((child.list_counts()[0] for child in self.value)),
2367                              self.max_qty * sum((child.list_counts()[1] for child in self.value))),
2368            "UNION": lambda: (self.min_qty * min((child.list_counts()[0] for child in self.value)),
2369                              self.max_qty * max((child.list_counts()[1] for child in self.value))),
2370            "OTHER": lambda: (self.min_qty * self.my_types[self.value].list_counts()[0],
2371                              self.max_qty * self.my_types[self.value].list_counts()[1]),
2372        }[self.type]())
2373        return retval
2374
2375    def xcode_list(self):
2376        """Return the full code needed to encode/decode a "LIST" or "MAP" element with children."""
2377        start_func = f"zcbor_{self.type.lower()}_start_{self.mode}"
2378        end_func = f"zcbor_{self.type.lower()}_end_{self.mode}"
2379        end_func_force = f"zcbor_list_map_end_force_{self.mode}"
2380        assert start_func in [
2381            "zcbor_list_start_decode", "zcbor_list_start_encode",
2382            "zcbor_map_start_decode", "zcbor_map_start_encode"]
2383        assert end_func in [
2384            "zcbor_list_end_decode", "zcbor_list_end_encode",
2385            "zcbor_map_end_decode", "zcbor_map_end_encode"]
2386        assert self.type in ["LIST", "MAP"], \
2387            "Expected LIST or MAP type, was %s." % self.type
2388        _, max_counts = zip(
2389            *(child.list_counts() for child in self.value)) if self.value else ((0,), (0,))
2390        count_arg = f', {str(sum(max_counts))}' if self.mode == 'encode' else ''
2391        with_children = "(%s && ((%s) || (%s, false)) && %s)" % (
2392            f"{start_func}(state{count_arg})",
2393            f"{newl_ind}&& ".join(child.full_xcode() for child in self.value),
2394            f"{end_func_force}(state)",
2395            f"{end_func}(state{count_arg})")
2396        without_children = "(%s && %s)" % (
2397            f"{start_func}(state{count_arg})",
2398            f"{end_func}(state{count_arg})")
2399        return with_children if len(self.value) > 0 else without_children
2400
2401    def xcode_group(self, union_int=None):
2402        """Return the full code needed to encode/decode a "GROUP" element's children."""
2403        assert self.type in ["GROUP"], "Expected GROUP type."
2404        return "(%s)" % (newl_ind + "&& ").join(
2405            [self.value[0].full_xcode(union_int)]
2406            + [child.full_xcode() for child in self.value[1:]])
2407
    def xcode_union(self):
        """Return the full code needed to encode/decode a "UNION" element's children.

        Decoding has two shapes. When all children are disambiguated by an integer, the
        integer is decoded straight into the choice variable and then compared against
        each child's enum value. Otherwise each child is tried in turn and the choice
        variable is assigned by the first child that succeeds.

        Encoding produces a chain of ternaries keyed on the choice variable.
        """
        assert self.type in ["UNION"], "Expected UNION type."
        if self.mode == "decode":
            if self.all_children_int_disambiguated():
                # One "(choice == enum) && (child code)" alternative per child. The
                # children are xcoded with union_int="DROP".
                lines = []
                lines.extend(
                    ["((%s == %s) && (%s))" %
                        (self.choice_var_access(), child.enum_var_name(),
                            child.full_xcode(union_int="DROP"))
                        for child in self.value])
                # NOTE(review): bit_size is assigned but not used below.
                bit_size = self.value[0].bit_size()
                func = f"zcbor_uint_{self.mode}" if self.all_children_uint_disambiguated() else \
                       f"zcbor_int_{self.mode}"
                # Decode the integer into the choice variable, then require one of the
                # alternatives to hold; otherwise ZCBOR_ERR_WRONG_VALUE is raised.
                return "((%s) && (%s))" % (
                    f"({func}(state, &{self.choice_var_access()}, "
                    + f"sizeof({self.choice_var_access()})))",
                    "((" + f"{newl_ind}|| ".join(lines)
                         + ") || (zcbor_error(state, ZCBOR_ERR_WRONG_VALUE), false))",)

            # One "(child code && ((choice = enum), true))" alternative per child.
            # Int-disambiguated children are xcoded with union_int="EXPECT".
            child_values = ["(%s && ((%s = %s), true))" %
                            (child.full_xcode(
                                union_int="EXPECT" if child.is_int_disambiguated() else None),
                                self.choice_var_access(), child.enum_var_name())
                            for child in self.value]

            # Reset state for all but the first child.
            # The reset call is only added for a child that is not int-disambiguated and
            # whose preceding sibling satisfies simple_func_condition().
            for i in range(1, len(child_values)):
                if ((not self.value[i].is_int_disambiguated())
                        and self.value[i - 1].simple_func_condition()):
                    child_values[i] = f"(zcbor_union_elem_code(state) && {child_values[i]})"

            # int_res keeps the result of the alternatives so that
            # zcbor_union_end_code() is called whether or not a child matched.
            return "(%s && (int_res = (%s), %s, int_res))" \
                % ("zcbor_union_start_code(state)",
                   f"{newl_ind}|| ".join(child_values),
                   "zcbor_union_end_code(state)")
        else:
            return ternary_if_chain(
                self.choice_var_access(),
                [child.enum_var_name() for child in self.value],
                [child.full_xcode() for child in self.value])
2449
2450    def xcode_bstr(self):
2451        if self.cbor and not self.cbor.is_entry_type():
2452            access_arg = f', {deref_if_not_null(self.val_access())}' if self.mode == 'decode' \
2453                else ''
2454            res_arg = f', &tmp_str' if self.mode == 'encode' \
2455                else ''
2456            xcode_cbor = "(%s)" % ((newl_ind + "&& ").join(
2457                [f"zcbor_bstr_start_{self.mode}(state{access_arg})",
2458                 f"(int_res = ({self.cbor.full_xcode()}), "
2459                 f"zcbor_bstr_end_{self.mode}(state{res_arg}), int_res)"]))
2460            if self.mode == "decode" or self.is_unambiguous():
2461                return xcode_cbor
2462            else:
2463                return f"({self.val_access()}.value " \
2464                    f"? (memcpy(&tmp_str, &{self.val_access()}, sizeof(tmp_str)), " \
2465                    f"{self.xcode_single_func_prim()}) : ({xcode_cbor}))"
2466        return self.xcode_single_func_prim()
2467
2468    def xcode_tags(self):
2469        return [f"zcbor_tag_{'put' if (self.mode == 'encode') else 'expect'}(state, {tag})"
2470                for tag in self.tags]
2471
2472    def value_suffix(self, value_str):
2473        """Appends ULL or LL if a value exceeding 32-bits is used"""
2474        if not value_str.isdigit():
2475            return ""
2476        value = int(value_str)
2477        if self.type == "INT" or self.type == "NINT":
2478            if value > INT32_MAX or value <= INT32_MIN:
2479                return "LL"
2480        elif self.type == "UINT":
2481            if value > UINT32_MAX:
2482                return "ULL"
2483
2484        return ""
2485
2486    def range_checks(self, access):
2487        """Return the code needed to check the size/value bounds of this element."""
2488        if self.type != "OTHER" and self.value is not None:
2489            return []
2490
2491        range_checks = []
2492
2493        # Remove unneeded checks when the bounds are (U)INT64_(MIN|MAX)
2494        exc_vals = [UINT64_MAX, 0] if self.type == "UINT" else [INT64_MAX, INT64_MIN]
2495        min_val = self.min_value if self.min_value not in exc_vals else None
2496        max_val = self.max_value if self.max_value not in exc_vals else None
2497
2498        if self.type in ["INT", "UINT", "NINT", "FLOAT", "BOOL"]:
2499            if min_val is not None and min_val == max_val:
2500                range_checks.append(f"({access} == {val_to_str(min_val)}"
2501                                    f"{self.value_suffix(val_to_str(min_val))})")
2502            else:
2503                if min_val is not None:
2504                    range_checks.append(f"({access} >= {val_to_str(min_val)}"
2505                                        f"{self.value_suffix(val_to_str(min_val))})")
2506                if max_val is not None:
2507                    range_checks.append(f"({access} <= {val_to_str(max_val)}"
2508                                        f"{self.value_suffix(val_to_str(max_val))})")
2509            if self.bits:
2510                range_checks.append(
2511                    f"!({access} & ~("
2512                    + ' | '.join([f'(1 << {c.enum_var_name()})'
2513                                 for c in self.my_control_groups[self.bits].value])
2514                    + "))")
2515        elif self.type in ["BSTR", "TSTR"]:
2516            if self.min_size is not None and self.min_size == self.max_size:
2517                range_checks.append(f"({access}.len == {val_to_str(self.min_size)})")
2518            else:
2519                if self.min_size is not None:
2520                    range_checks.append(f"({access}.len >= {val_to_str(self.min_size)})")
2521                if self.max_size is not None:
2522                    range_checks.append(f"({access}.len <= {val_to_str(self.max_size)})")
2523        elif self.type == "OTHER":
2524            if not self.my_types[self.value].single_func_impl_condition():
2525                range_checks.extend(self.my_types[self.value].range_checks(access))
2526
2527        if range_checks:
2528            range_checks[0] = "((" + range_checks[0]
2529            range_checks[-1] = range_checks[-1] \
2530                + ") || (zcbor_error(state, ZCBOR_ERR_WRONG_RANGE), false))"
2531
2532        return range_checks
2533
2534    def repeated_xcode(self, union_int=None):
2535        """Return the full code needed to encode/decode this element.
2536
2537        Including children, key and cbor, excluding repetitions.
2538        """
2539        val_union_int = union_int if not self.key else None  # In maps, only pass union_int to key.
2540        range_checks = self.range_checks(self.val_access())
2541        xcoder = {
2542            "INT": self.xcode_single_func_prim,
2543            "UINT": lambda: self.xcode_single_func_prim(val_union_int),
2544            "NINT": lambda: self.xcode_single_func_prim(val_union_int),
2545            "FLOAT": self.xcode_single_func_prim,
2546            "BSTR": self.xcode_bstr,
2547            "TSTR": self.xcode_single_func_prim,
2548            "BOOL": self.xcode_single_func_prim,
2549            "NIL": self.xcode_single_func_prim,
2550            "UNDEF": self.xcode_single_func_prim,
2551            "ANY": self.xcode_single_func_prim,
2552            "LIST": self.xcode_list,
2553            "MAP": self.xcode_list,
2554            "GROUP": lambda: self.xcode_group(val_union_int),
2555            "UNION": self.xcode_union,
2556            "OTHER": lambda: self.xcode_single_func_prim(val_union_int),
2557        }[self.type]
2558        xcoders = []
2559        if self.key:
2560            xcoders.append(self.key.full_xcode(union_int))
2561        if self.tags:
2562            xcoders.extend(self.xcode_tags())
2563        if self.mode == "decode":
2564            xcoders.append(xcoder())
2565            xcoders.extend(range_checks)
2566        elif self.type == "BSTR" and self.cbor:
2567            xcoders.append(xcoder())
2568            xcoders.extend(self.range_checks("tmp_str"))
2569        else:
2570            xcoders.extend(range_checks)
2571            xcoders.append(xcoder())
2572
2573        return "(%s)" % ((newl_ind + "&& ").join(xcoders),)
2574
2575    def result_len(self):
2576        """Code for the size of the repeated part of this element."""
2577        if self.repeated_type_name() is None or self.is_unambiguous_repeated():
2578            return "0"
2579        else:
2580            return "sizeof(%s)" % self.repeated_type_name()
2581
2582    def full_xcode(self, union_int=None):
2583        """Return the full code needed to encode/decode this element.
2584
2585        Including children, key, cbor, and repetitions.
2586        """
2587        if self.present_var_condition():
2588            if self.mode == "encode":
2589                func, *arguments = self.repeated_single_func(ptr_result=False)
2590                return f"(!{self.present_var_access()} || {func}({xcode_args(*arguments)}))"
2591            else:
2592                assert self.mode == "decode", \
2593                    f"This code needs self.mode to be 'decode', not {self.mode}."
2594                if not self.repeated_single_func_impl_condition():
2595                    decode_str = self.repeated_xcode(union_int)
2596                    return f"({self.present_var_access()} = {self.repeated_xcode(union_int)}, 1)"
2597                func, *arguments = self.repeated_single_func(ptr_result=True)
2598                return (
2599                    f"zcbor_present_decode(&(%s), (zcbor_decoder_t *)%s, %s)" %
2600                    (self.present_var_access(), func, xcode_args(*arguments),))
2601        elif self.count_var_condition():
2602            func, arg = self.repeated_single_func(ptr_result=True)
2603
2604            minmax = "_minmax" if self.mode == "encode" else ""
2605            mode = self.mode
2606            return (
2607                f"zcbor_multi_{mode}{minmax}(%s, %s, &%s, (zcbor_{mode}r_t *)%s, %s, %s)" %
2608                (self.min_qty,
2609                 self.max_qty,
2610                 self.count_var_access(),
2611                 func,
2612                 xcode_args("*" + arg if arg != "NULL" and self.result_len() != "0" else arg),
2613                 self.result_len()))
2614        else:
2615            return self.repeated_xcode(union_int)
2616
2617    def xcode(self):
2618        """Return the body of the encoder/decoder function for this element."""
2619        return self.full_xcode()
2620
2621    def xcoders(self):
2622        """Recursively return a list of the bodies of the encoder/decoder functions for
2623        this element and its children + key + cbor.
2624        """
2625        if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
2626            for child in self.value:
2627                for xcoder in child.xcoders():
2628                    yield xcoder
2629        if self.cbor:
2630            for xcoder in self.cbor.xcoders():
2631                yield xcoder
2632        if self.key:
2633            for xcoder in self.key.xcoders():
2634                yield xcoder
2635        if self.type == "OTHER" and self.value not in self.entry_type_names:
2636            for xcoder in self.my_types[self.value].xcoders():
2637                yield xcoder
2638        if self.repeated_single_func_impl_condition():
2639            yield XcoderTuple(
2640                self.repeated_xcode(), self.repeated_xcode_func_name(), self.repeated_type_name())
2641        if (self.single_func_impl_condition()):
2642            xcode_body = self.xcode()
2643            yield XcoderTuple(xcode_body, self.xcode_func_name(), self.type_name())
2644
2645    def public_xcode_func_sig(self):
2646        type_name = self.type_name() if struct_ptr_name(self.mode) in self.full_xcode() else "void"
2647        return f"""
2648int cbor_{self.xcode_func_name()}(
2649		{"const " if self.mode == "decode" else ""}uint8_t *payload, size_t payload_len,
2650		{"" if self.mode == "decode" else "const "}{type_name} *{struct_ptr_name(self.mode)},
2651		{"size_t *payload_len_out"})"""
2652
2653
2654class CodeRenderer():
2655    def __init__(self, entry_types, modes, print_time, default_max_qty, git_sha='', file_header=''):
2656        self.entry_types = entry_types
2657        self.print_time = print_time
2658        self.default_max_qty = default_max_qty
2659
2660        self.sorted_types = dict()
2661        self.functions = dict()
2662        self.type_defs = dict()
2663
2664        # Sort type definitions so the typedefs will come in the correct order in the header file
2665        # and the function in the correct order in the c file.
2666        for mode in modes:
2667            self.sorted_types[mode] = list(sorted(
2668                self.entry_types[mode], key=lambda _type: _type.depends_on(), reverse=False))
2669
2670            self.functions[mode] = self.unique_funcs(mode)
2671            self.functions[mode] = self.used_funcs(mode)
2672            self.type_defs[mode] = self.unique_types(mode)
2673
2674        self.version = __version__
2675
2676        if git_sha:
2677            self.version += f'-{git_sha}'
2678
2679        self.file_header = file_header.strip() + "\n\n" if file_header.strip() else ""
2680        self.file_header += f"""Generated using zcbor version {self.version}
2681https://github.com/NordicSemiconductor/zcbor{'''
2682at: ''' + datetime.now().strftime('%Y-%m-%d %H:%M:%S') if self.print_time else ''}
2683Generated with a --default-max-qty of {self.default_max_qty}"""
2684
2685    def header_guard(self, file_name):
2686        return path.basename(file_name).replace(".", "_").replace("-", "_").upper() + "__"
2687
2688    def unique_types(self, mode):
2689        """Return a list of typedefs for all defined types, with duplicate typedefs
2690        removed.
2691        """
2692        type_names = {}
2693        out_types = []
2694        for mtype in self.sorted_types[mode]:
2695            for type_def in mtype.type_def():
2696                type_name = type_def[1]
2697                if type_name not in type_names.keys():
2698                    type_names[type_name] = type_def[0]
2699                    out_types.append(type_def)
2700                else:
2701                    assert (''.join(type_names[type_name]) == ''.join(type_def[0])), f"""
2702Two elements share the type name {type_name}, but their implementations are not identical.
2703Please change one or both names. They are
2704{linesep.join(type_names[type_name])}
2705and
2706{linesep.join(type_def[0])}"""
2707        return out_types
2708
2709    def unique_funcs(self, mode):
2710        """Return a list of encoder/decoder functions for all defined types, with duplicate
2711        functions removed.
2712        """
2713        func_names = {}
2714        out_types = []
2715        for mtype in self.sorted_types[mode]:
2716            xcoders = list(mtype.xcoders())
2717            for funcType in xcoders:
2718                func_xcode = funcType[0]
2719                func_name = funcType[1]
2720                if func_name not in func_names.keys():
2721                    func_names[func_name] = funcType
2722                    out_types.append(funcType)
2723                elif func_name in func_names.keys():
2724                    assert func_names[func_name][0] == func_xcode, \
2725                        ("Two elements share the function name %s, but their implementations are "
2726                            + "not identical. Please change one or both names.\n\n%s\n\n%s") % \
2727                        (func_name, func_names[func_name][0], func_xcode)
2728
2729        return out_types
2730
2731    def used_funcs(self, mode):
2732        """Return a list of encoder/decoder functions for all defined types, with unused
2733        functions removed.
2734        """
2735        mod_entry_types = [
2736            XcoderTuple(
2737                func_type.xcode(),
2738                func_type.xcode_func_name(),
2739                func_type.type_name()) for func_type in self.entry_types[mode]]
2740        out_types = [func_type for func_type in mod_entry_types]
2741        full_code = "".join([func_type[0] for func_type in mod_entry_types])
2742        for func_type in reversed(self.functions[mode]):
2743            func_name = func_type[1]
2744            if func_type not in mod_entry_types and getrp(r"%s\W" % func_name).search(full_code):
2745                full_code += func_type[0]
2746                out_types.append(func_type)
2747        return list(reversed(out_types))
2748
2749    def render_forward_declaration(self, xcoder, mode):
2750        """Render a single decoding function with signature and body."""
2751        return f"""
2752static bool {xcoder.func_name}(zcbor_state_t *state, {"" if mode == "decode" else "const "}{
2753            xcoder.type_name
2754            if struct_ptr_name(mode) in xcoder.body else "void"} *{struct_ptr_name(mode)});
2755            """.strip()
2756
    def render_function(self, xcoder, mode):
        """Render the definition of a single static xcoder function.

        xcoder provides the body, func_name and type_name of the function. The result
        pointer is const when encoding, and its type is "void" when the body never
        refers to it. Declarations for tmp_str and int_res are only emitted when the
        body mentions them.

        If the body passes a function pointer (cast to zcbor_encoder_t/zcbor_decoder_t)
        to one of the calls matched below, an unreachable "if (false)" block is added
        that calls those functions directly, so the C compiler checks their argument
        types.
        """
        body = xcoder.body

        # NOTE: these patterns use (?(DEFINE)...), (?&name) and atomic groups, so
        # getrp() must compile them with the "regex" module rather than the stdlib "re".

        # Define the subroutine "paren" that matches parenthesised expressions.
        paren_re = r'(?(DEFINE)(?P<paren>\(((?>[^\(\)]+|(?&paren))*)\)))'
        # This uses "paren" to match a single argument to a function.
        arg_re = rf'([^,\(\)]|(?&paren))+'
        # Match a function pointer argument to a function.
        func_re = rf'\(zcbor_(en|de)coder_t \*\)(?P<func>{arg_re})'
        # Match a triplet of function pointer, state arg, and result arg.
        call_re = rf'{func_re}, (?P<state>{arg_re}), (?P<arg>{arg_re})'
        # NOTE(review): this requires "(" directly after "zcbor_multi_encode", so it
        # looks like it cannot match the "zcbor_multi_encode_minmax(" calls that
        # full_xcode() emits when encoding. Confirm whether that is intended.
        multi_re = rf'{paren_re}zcbor_multi_(en|de)code\(({arg_re},){{3}} {call_re}'
        present_re = rf'{paren_re}zcbor_present_(en|de)code\({arg_re}, {call_re}\)'
        map_re = rf'{paren_re}zcbor_unordered_map_search\({call_re}\)'
        all_funcs = chain(getrp(multi_re).finditer(body),
                          getrp(present_re).finditer(body),
                          getrp(map_re).finditer(body))
        arg_test = ""
        # One direct call per match, e.g. "func(state, arg);".
        calls = ("\n		".join(
            (f"{m.group('func')}({m.group('state')}, {m.group('arg')});" for m in (all_funcs))))
        if calls != "":
            arg_test = f"""
	if (false) {{
		/* For testing that the types of the arguments are correct.
		 * A compiler error here means a bug in zcbor.
		 */
		{calls}
	}}
"""
        return f"""
static bool {xcoder.func_name}(
		zcbor_state_t *state, {"" if mode == "decode" else "const "}{
            xcoder.type_name
            if struct_ptr_name(mode) in body else "void"} *{struct_ptr_name(mode)})
{{
	zcbor_log("%s\\r\\n", __func__);
	{"struct zcbor_string tmp_str;" if "tmp_str" in body else ""}
	{"bool int_res;" if "int_res" in body else ""}

	bool res = ({body});
{arg_test}
	log_result(state, res, __func__);
	return res;
}}""".replace("	\n", "")  # call replace() to remove empty lines.
2801
2802    def render_entry_function(self, xcoder, mode):
2803        """Render a single entry function (API function) with signature and body."""
2804        func_name, func_arg = (xcoder.xcode_func_name(), struct_ptr_name(mode))
2805        return f"""
2806{xcoder.public_xcode_func_sig()}
2807{{
2808	zcbor_state_t states[{xcoder.num_backups() + 2}];
2809
2810	return zcbor_entry_function(payload, payload_len, (void *){func_arg}, payload_len_out, states,
2811		(zcbor_decoder_t *){func_name}, sizeof(states) / sizeof(zcbor_state_t), {
2812            xcoder.list_counts()[1]});
2813}}"""
2814
2815    def render_file_header(self, line_prefix):
2816        lp = line_prefix
2817        return (f"\n{lp} " + self.file_header.replace("\n", f"\n{lp} ")).replace(" \n", "\n")
2818
2819    def render_c_file(self, header_file_name, mode):
2820        """Render the entire generated C file contents."""
2821        log_result_define = """#define log_result(state, result, func) \
2822do { \\
2823	if (!result) { \\
2824		zcbor_trace_file(state); \\
2825		zcbor_log("%s error: %s\\r\\n", func, zcbor_error_str(zcbor_peek_error(state))); \\
2826	} else { \\
2827		zcbor_log("%s success\\r\\n", func); \\
2828	} \\
2829} while(0)"""
2830        return f"""/*{self.render_file_header(" *")}
2831 */
2832
2833#include <stdint.h>
2834#include <stdbool.h>
2835#include <stddef.h>
2836#include <string.h>
2837#include "zcbor_{mode}.h"
2838#include "{header_file_name}"
2839#include "zcbor_print.h"
2840
2841#if DEFAULT_MAX_QTY != {self.default_max_qty}
2842#error "The type file was generated with a different default_max_qty than this file"
2843#endif
2844
2845{log_result_define}
2846
2847{linesep.join([self.render_forward_declaration(xcoder, mode) for xcoder in self.functions[mode]])}
2848
2849{linesep.join([self.render_function(xcoder, mode) for xcoder in self.functions[mode]])}
2850
2851{linesep.join([self.render_entry_function(xcoder, mode) for xcoder in self.entry_types[mode]])}
2852"""
2853
2854    def render_h_file(self, type_def_file, header_guard, mode):
2855        """Render the entire generated header file contents."""
2856        return \
2857            f"""/*{self.render_file_header(" *")}
2858 */
2859
2860#ifndef {header_guard}
2861#define {header_guard}
2862
2863#include <stdint.h>
2864#include <stdbool.h>
2865#include <stddef.h>
2866#include <string.h>
2867#include "{type_def_file}"
2868
2869#ifdef __cplusplus
2870extern "C" {{
2871#endif
2872
2873#if DEFAULT_MAX_QTY != {self.default_max_qty}
2874#error "The type file was generated with a different default_max_qty than this file"
2875#endif
2876
2877{(linesep * 2).join([f"{xcoder.public_xcode_func_sig()};" for xcoder in self.entry_types[mode]])}
2878
2879
2880#ifdef __cplusplus
2881}}
2882#endif
2883
2884#endif /* {header_guard} */
2885"""
2886
2887    def render_type_file(self, header_guard, mode):
2888        body = (
2889            linesep + linesep).join(
2890                [f"{typedef[1]} {{{linesep}{linesep.join(typedef[0][1:])};"
2891                    for typedef in self.type_defs[mode]])
2892        return \
2893            f"""/*{self.render_file_header(" *")}
2894 */
2895
2896#ifndef {header_guard}
2897#define {header_guard}
2898
2899#include <stdint.h>
2900#include <stdbool.h>
2901#include <stddef.h>
2902{'#include <zcbor_common.h>' if "struct zcbor_string" in body else ""}
2903
2904#ifdef __cplusplus
2905extern "C" {{
2906#endif
2907
2908/** Which value for --default-max-qty this file was created with.
2909 *
2910 *  The define is used in the other generated file to do a build-time
2911 *  compatibility check.
2912 *
2913 *  See `zcbor --help` for more information about --default-max-qty
2914 */
2915#define DEFAULT_MAX_QTY {self.default_max_qty}
2916
2917{body}
2918
2919#ifdef __cplusplus
2920}}
2921#endif
2922
2923#endif /* {header_guard} */
2924"""
2925
2926    def render_cmake_file(self, target_name, h_files, c_files, type_file,
2927                          output_c_dir, output_h_dir, cmake_dir):
2928        include_dirs = sorted(set(((Path(output_h_dir)),
2929                                  (Path(type_file.name).parent),
2930                                  *((Path(h.name).parent) for h in h_files.values()))))
2931
2932        def relativify(p):
2933            try:
2934                return PurePosixPath(
2935                    Path("${CMAKE_CURRENT_LIST_DIR}") / path.relpath(Path(p), cmake_dir))
2936            except ValueError:
2937                # On Windows, the above will fail if the paths are on different drives.
2938                return Path(p).absolute().as_posix()
2939        return \
2940            f"""\
2941#{self.render_file_header("#")}
2942#
2943
2944add_library({target_name})
2945target_sources({target_name} PRIVATE
2946    {relativify(Path(output_c_dir, "zcbor_decode.c"))}
2947    {relativify(Path(output_c_dir, "zcbor_encode.c"))}
2948    {relativify(Path(output_c_dir, "zcbor_common.c"))}
2949    {relativify(Path(output_c_dir, "zcbor_print.c"))}
2950    {(linesep + "    ").join(((str(relativify(c.name))) for c in c_files.values()))}
2951    )
2952target_include_directories({target_name} PUBLIC
2953    {(linesep + "    ").join(((str(relativify(f)) for f in include_dirs)))}
2954    )
2955"""
2956
2957    def render(self, modes, h_files, c_files, type_file, include_prefix, cmake_file=None,
2958               output_c_dir=None, output_h_dir=None):
2959        for mode in modes:
2960            h_name = Path(include_prefix, Path(h_files[mode].name).name)
2961
2962            # Create and populate the generated c and h file.
2963            makedirs(path.dirname(Path(c_files[mode].name).absolute()), exist_ok=True)
2964
2965            type_def_name = Path(include_prefix, Path(type_file.name).name)
2966
2967            print("Writing to " + c_files[mode].name)
2968            c_files[mode].write(self.render_c_file(h_name, mode))
2969
2970            print("Writing to " + h_files[mode].name)
2971            h_files[mode].write(self.render_h_file(
2972                type_def_name,
2973                self.header_guard(h_files[mode].name), mode))
2974
2975        print("Writing to " + type_file.name)
2976        type_file.write(self.render_type_file(self.header_guard(type_file.name), mode))
2977
2978        if cmake_file:
2979            print("Writing to " + cmake_file.name)
2980            cmake_file.write(self.render_cmake_file(
2981                Path(cmake_file.name).stem, h_files, c_files, type_file,
2982                output_c_dir, output_h_dir, Path(cmake_file.name).absolute().parent))
2983
2984
def int_or_str(arg):
    """Argument type for argparse that accepts an integer or a plain symbol name.

    Returns `arg` converted to int if possible. Otherwise returns `arg` unchanged if it
    consists only of word characters, and raises ArgumentTypeError if it does not.
    """
    try:
        return int(arg)
    except ValueError:
        pass

    if getrp(r"\A\w+\Z").match(arg) is None:
        raise ArgumentTypeError(
            "Argument must be an integer or a string with only letters, numbers, or '_'.")
    return arg
2994
2995
def parse_args():
    """Build the command line parser, parse sys.argv, and return the resulting namespace.

    There are three sub-commands: "code", "validate", and "convert". Each stores its handler
    function in `args.process`. Unless --no-prelude is given, the CDDL prelude file is opened
    and appended to `args.cddl`.

    Exits with a usage error (via parser.error()) when no sub-command is given, when "code"
    is used without --decode or --encode, or when "code" is used without enough output options.
    """
    # Options shared by all sub-commands.
    parent_parser = ArgumentParser(add_help=False)

    parent_parser.add_argument(
        "-c", "--cddl", required=True, type=FileType('r', encoding='utf-8'), action="append",
        help="""Path to one or more input CDDL file(s). Passing multiple files is equivalent to
concatenating them.""")
    parent_parser.add_argument(
        "--no-prelude", required=False, action="store_true", default=False,
        help=f"""Exclude the standard CDDL prelude from the build. The prelude can be viewed at
{PRELUDE_PATH.relative_to(PACKAGE_PATH)} in the repo, or together with the script.""")
    parent_parser.add_argument(
        "-v", "--verbose", required=False, action="store_true", default=False,
        help="Print more information while parsing CDDL and generating code.")

    parser = ArgumentParser(
        description='''Parse a CDDL file and validate/convert between YAML, JSON, and CBOR.
Can also generate C code for validation/encoding/decoding of CBOR.''')

    parser.add_argument(
        "--version", action="version", version=f"zcbor {__version__}")

    subparsers = parser.add_subparsers()
    code_parser = subparsers.add_parser(
        "code", description='''Parse a CDDL file and produce C code that validates and xcodes CBOR.
The output from this script is a C file and a header file. The header file
contains typedefs for all the types specified in the cddl input file, as well
as declarations to xcode functions for the types designated as entry types when
running the script. The c file contains all the code for decoding and validating
the types in the CDDL input file. All types are validated as they are xcoded.

Where a `bstr .cbor <Type>` is specified in the CDDL, AND the Type is an entry
type, the xcoder will not xcode the string, only provide a pointer into the
payload buffer. This is useful to reduce the size of typedefs, or to break up
decoding. Using this mechanism is necessary when the CDDL contains self-
referencing types, since the C type cannot be self referencing.

This script requires 'regex' for lookaround functionality not present in 're'.''',
        formatter_class=RawDescriptionHelpFormatter,
        parents=[parent_parser])

    code_parser.add_argument(
        "--default-max-qty", "--dq", required=False, type=int_or_str, default=3,
        help="""Default maximum number of repetitions when no maximum
is specified. This is needed to construct complete C types.

The default_max_qty can usually be set to a text symbol if desired,
to allow it to be configurable when building the code. This is not always
possible, as sometimes the value is needed for internal computations.
If so, the script will raise an exception.""")
    code_parser.add_argument(
        "--output-c", "--oc", required=False, type=str,
        help="""Path to output C file. If both --decode and --encode are specified, _decode and
_encode will be appended to the filename when creating the two files. If not
specified, the path and name will be based on the --output-cmake file. A 'src'
directory will be created next to the cmake file, and the C file will be
placed there with the same name (except the file extension) as the cmake file.""")
    code_parser.add_argument(
        "--output-h", "--oh", required=False, type=str,
        help="""Path to output header file. If both --decode and --encode are specified, _decode and
_encode will be appended to the filename when creating the two files. If not
specified, the path and name will be based on the --output-cmake file. An 'include'
directory will be created next to the cmake file, and the C file will be
placed there with the same name (except the file extension) as the cmake file.""")
    code_parser.add_argument(
        "--output-h-types", "--oht", required=False, type=str,
        help="""Path to output header file with typedefs (shared between decode and encode).
If not specified, the path and name will be taken from the output header file
(--output-h), with '_types' added to the file name.""")
    code_parser.add_argument(
        "--copy-sources", required=False, action="store_true", default=False,
        help="""Copy the non-generated source files (zcbor_*.c/h) into the same directories as the
generated files.""")
    code_parser.add_argument(
        "--output-cmake", required=False, type=str,
        help="""Path to output CMake file. The filename of the CMake file without '.cmake' is used
as the name of the CMake target in the file.
The CMake file defines a CMake target with the zcbor source files and the
generated file as sources, and the zcbor header files' and generated header
files' folders as include_directories.
Add it to your project via include() in your CMakeLists.txt file, and link the
target to your program.
This option works with or without the --copy-sources option.""")
    code_parser.add_argument(
        "-t", "--entry-types", required=True, type=str, nargs="+",
        help="Names of the types which should have their xcode functions exposed.")
    code_parser.add_argument(
        "-d", "--decode", required=False, action="store_true", default=False,
        help="Generate decoding code. Either --decode or --encode or both must be specified.")
    code_parser.add_argument(
        "-e", "--encode", required=False, action="store_true", default=False,
        help="Generate encoding code. Either --decode or --encode or both must be specified.")
    code_parser.add_argument(
        "--time-header", required=False, action="store_true", default=False,
        help="Put the current time in a comment in the generated files.")
    code_parser.add_argument(
        "--git-sha-header", required=False, action="store_true", default=False,
        help="Put the current git sha of zcbor in a comment in the generated files.")
    code_parser.add_argument(
        "-b", "--default-bit-size", required=False, type=int, default=32, choices=[32, 64],
        help="""Default bit size of integers in code. When integers have no explicit bounds,
assume they have this bit width. Should follow the bit width of the architecture
the code will be running on.""")
    code_parser.add_argument(
        "--include-prefix", default="",
        help="""When #include'ing generated files, add this path prefix to the filename.""")
    code_parser.add_argument(
        "-s", "--short-names", required=False, action="store_true", default=False,
        help="""Attempt to make most generated struct member names shorter. This might make some
names identical which will cause a compile error. If so, tweak the CDDL labels
or layout, or disable this option. This might also make enum names different
from the corresponding union members.""")
    code_parser.add_argument(
        "--file-header", required=False, type=str, default="",
        help="""Header to be included in the comment at the top of generated files, e.g. copyright.
Can be a string or a path to a file. If interpreted as a path to an existing file,
the file's contents will be used.""")
    code_parser.set_defaults(process=process_code)

    # Options shared by the "validate" and "convert" sub-commands.
    validate_parent_parser = ArgumentParser(add_help=False)
    validate_parent_parser.add_argument(
        "-i", "--input", required=True, type=str,
        help='''Input data file. The option --input-as specifies how to interpret the contents.
Use "-" to indicate stdin.''')
    validate_parent_parser.add_argument(
        "--input-as", required=False, choices=["yaml", "json", "cbor", "cborhex"],
        help='''Which format to interpret the input file as.
If omitted, the format is inferred from the file name.
.yaml, .yml => YAML, .json => JSON, .cborhex => CBOR as hex string, everything else => CBOR''')
    validate_parent_parser.add_argument(
        "-t", "--entry-type", required=True, type=str,
        help='''Name of the type (from the CDDL) to interpret the data as.''')
    validate_parent_parser.add_argument(
        "--default-max-qty", "--dq", required=False, type=int, default=0xFFFFFFFF,
        help="""Default maximum number of repetitions when no maximum is specified.
It is only relevant when handling data that will be decoded by generated code.
If omitted, a large number will be used.""")
    validate_parent_parser.add_argument(
        "--yaml-compatibility", required=False, action="store_true", default=False,
        help='''Whether to convert CBOR-only values to YAML-compatible ones
(when converting from CBOR), or vice versa (when converting to CBOR).

When this is enabled, all CBOR data is guaranteed to convert into YAML/JSON.
JSON and YAML do not support all data types that CBOR/CDDL supports.
bytestrings (BSTR), tags, undefined, and maps with non-text keys need
special handling. See the zcbor README for more information.''')

    validate_parser = subparsers.add_parser(
        "validate", description='''Read CBOR, YAML, or JSON data from file or stdin and validate
it against a CDDL schema file.
        ''',
        parents=[parent_parser, validate_parent_parser])

    validate_parser.set_defaults(process=process_validate)

    convert_parser = subparsers.add_parser(
        "convert", description='''Parse a CDDL file and validate/convert between CBOR and YAML/JSON.
The script decodes the CBOR/YAML/JSON data from a file or stdin
and verifies that it conforms to the CDDL description.
The script fails if the data does not conform.
'zcbor validate' can be used if only validate is needed.''',
        parents=[parent_parser, validate_parent_parser])

    convert_parser.add_argument(
        "-o", "--output", required=True, type=str,
        help='''Output data file. The option --output-as specifies how to interpret the contents.
 Use "-" to indicate stdout.''')
    convert_parser.add_argument(
        "--output-as", required=False, choices=["yaml", "json", "cbor", "cborhex", "c_code"],
        help='''Which format to interpret the output file as.
If omitted, the format is inferred from the file name.
.yaml, .yml => YAML, .json => JSON, .c, .h => C code,
.cborhex => CBOR as hex string, everything else => CBOR''')
    convert_parser.add_argument(
        "--c-code-var-name", required=False, type=str,
        help='''Only relevant together with '--output-as c_code' or .c files.''')
    convert_parser.add_argument(
        "--c-code-columns", required=False, type=int, default=0,
        help='''Only relevant together with '--output-as c_code' or .c files.
The number of bytes per line in the variable instantiation. If omitted, the
entire declaration is a single line.''')
    convert_parser.set_defaults(process=process_convert)

    args = parser.parse_args()

    if not hasattr(args, "process"):
        # No sub-command was given, so neither the handler nor any of the sub-command options
        # (e.g. no_prelude) exist in the namespace. Report it instead of failing further down
        # with an AttributeError.
        parser.error("Please specify a sub-command: code, validate, or convert.")

    if not args.no_prelude:
        # The file object is left open on purpose: it is read later together with the other
        # CDDL files.
        args.cddl.append(open(PRELUDE_PATH, 'r', encoding="utf-8"))

    # Only the "code" sub-command has the decode/encode and output_c options.
    if hasattr(args, "decode") and not args.decode and not args.encode:
        parser.error("Please specify at least one of --decode or --encode.")

    if hasattr(args, "output_c"):
        if not args.output_c or not args.output_h:
            if not args.output_cmake:
                parser.error(
                    "Please specify both --output-c and --output-h "
                    "unless --output-cmake is specified.")

    return args
3196
3197
def process_code(args):
    """Handle the "code" sub-command: parse the CDDL and generate the C code files.

    args: The namespace returned by parse_args() for the "code" sub-command.

    For each requested mode (decode/encode) a C file and a header file are written, plus one
    shared type definition header and, if --output-cmake is given, a CMake file. With
    --copy-sources, the non-generated zcbor source and header files are copied next to the
    generated files. All output files are closed before this function returns.
    """
    # Local import: ExitStack is only needed here.
    from contextlib import ExitStack

    modes = list()
    if args.decode:
        modes.append("decode")
    if args.encode:
        modes.append("encode")

    if args.file_header:
        try:
            header_is_path = Path(args.file_header).exists()
        except OSError:
            # E.g. "File name too long" for a long literal header: it cannot be a path to an
            # existing file, so it is used verbatim.
            header_is_path = False
        if header_is_path:
            args.file_header = Path(args.file_header).read_text(encoding="utf-8")

    print("Parsing files: " + ", ".join((c.name for c in args.cddl)))

    cddl_contents = linesep.join((c.read() for c in args.cddl))

    cddl_res = dict()
    for mode in modes:
        cddl_res[mode] = CodeGenerator.from_cddl(
            mode, cddl_contents, args.default_max_qty, mode, args.entry_types,
            args.default_bit_size, short_names=args.short_names)

    # Parsing is done, pretty print the result.
    verbose_print(args.verbose, "Parsed CDDL types:")
    for mode in modes:
        verbose_pprint(args.verbose, cddl_res[mode].my_types)

    git_sha = ''
    if args.git_sha_header:
        if "zcbor.py" in sys.argv[0]:
            # Running the script directly: ask git for the sha of the checkout.
            git_args = ['git', 'rev-parse', '--verify', '--short', 'HEAD']
            git_sha = Popen(
                git_args, cwd=PACKAGE_PATH, stdout=PIPE).communicate()[0].decode('utf-8').strip()
        else:
            git_sha = __version__

    def add_mode_to_fname(filename, mode):
        """Return filename with "_<mode>" inserted before the file extension."""
        name = Path(filename).stem + "_" + mode + Path(filename).suffix
        return Path(filename).with_name(name)

    # Every output file is registered with the ExitStack so that it is flushed and closed when
    # rendering is done, also if an error occurs along the way.
    with ExitStack() as open_files:
        def create_and_open(file_path):
            """Create the parent directory of file_path and open the file for text writing."""
            Path(file_path).absolute().parent.mkdir(parents=True, exist_ok=True)
            return open_files.enter_context(Path(file_path).open('w', encoding='utf-8'))

        if args.output_cmake:
            cmake_dir = Path(args.output_cmake).parent
            output_cmake = create_and_open(args.output_cmake)
            filenames = Path(args.output_cmake).parts[-1].replace(".cmake", "")
        else:
            # parse_args() guarantees that --output-c and --output-h are given in this case, so
            # cmake_dir and filenames are never needed below.
            output_cmake = None

        output_c = dict()
        output_h = dict()
        # With a single mode the given file names are used as they are; with two modes, the
        # mode is added to the names to keep the files apart.
        out_c = args.output_c if (len(modes) == 1 and args.output_c) else None
        out_h = args.output_h if (len(modes) == 1 and args.output_h) else None
        for mode in modes:
            output_c[mode] = create_and_open(
                out_c or add_mode_to_fname(
                    args.output_c or Path(cmake_dir, 'src', f'{filenames}.c'), mode))
            output_h[mode] = create_and_open(
                out_h or add_mode_to_fname(
                    args.output_h or Path(cmake_dir, 'include', f'{filenames}.h'), mode))

        out_c_parent = Path(output_c[modes[0]].name).parent
        out_h_parent = Path(output_h[modes[0]].name).parent

        output_h_types = create_and_open(
            args.output_h_types
            or (args.output_h
                and Path(args.output_h).with_name(Path(args.output_h).stem + "_types.h"))
            or Path(cmake_dir, 'include', filenames + '_types.h'))

        renderer = CodeRenderer(entry_types={mode: [cddl_res[mode].my_types[entry]
                                             for entry in args.entry_types] for mode in modes},
                                modes=modes, print_time=args.time_header,
                                default_max_qty=args.default_max_qty, git_sha=git_sha,
                                file_header=args.file_header
                                )

        c_code_dir = C_SRC_PATH
        h_code_dir = C_INCLUDE_PATH

        if args.copy_sources:
            new_c_code_dir = out_c_parent
            new_h_code_dir = out_h_parent
            for c_name in ("zcbor_decode.c", "zcbor_encode.c", "zcbor_common.c",
                           "zcbor_print.c"):
                copyfile(Path(c_code_dir, c_name), Path(new_c_code_dir, c_name))
            for h_name in ("zcbor_decode.h", "zcbor_encode.h", "zcbor_common.h",
                           "zcbor_tags.h", "zcbor_print.h"):
                copyfile(Path(h_code_dir, h_name), Path(new_h_code_dir, h_name))
            # The CMake file shall refer to the copies instead of the originals.
            c_code_dir = new_c_code_dir
            h_code_dir = new_h_code_dir

        renderer.render(modes, output_h, output_c, output_h_types, args.include_prefix,
                        output_cmake, c_code_dir, h_code_dir)
3294
3295
def parse_cddl(args):
    """Parse all CDDL files in args.cddl and return the type named by args.entry_type.

    The contents of the files are joined into a single CDDL text before parsing.
    """
    combined_cddl = linesep.join(cddl_file.read() for cddl_file in args.cddl)
    translator = DataTranslator.from_cddl(combined_cddl, args.default_max_qty)
    return translator.my_types[args.entry_type]
3300
3301
def read_data(args, cddl):
    """Read the input data, validate it against `cddl`, and return it as a CBOR byte string.

    args: Namespace with `input` (file path, or "-" for stdin), `input_as` (format name, or
        None to infer the format from the file extension), and `yaml_compatibility`.
    cddl: The CDDL type to validate against. YAML and JSON input is converted with
        cddl.from_yaml()/cddl.from_json(); CBOR input is checked with cddl.validate_str().

    The formats "yaml"/"yml", "json", and "cborhex" are read as UTF-8 text. Everything else
    is read as binary CBOR.
    """
    _, in_file_ext = path.splitext(args.input)
    in_file_format = args.input_as or in_file_ext.strip(".")
    is_text = in_file_format in ("yaml", "yml", "json", "cborhex")

    if args.input == "-":
        # stdin is not ours to close.
        raw = (sys.stdin if is_text else sys.stdin.buffer).read()
    elif is_text:
        with open(args.input, "r", encoding="utf-8") as f:
            raw = f.read()
    else:
        # Binary mode must not be given an encoding; open() raises ValueError if it is.
        with open(args.input, "rb") as f:
            raw = f.read()

    if in_file_format in ("yaml", "yml"):
        cbor_str = cddl.from_yaml(raw, yaml_compat=args.yaml_compatibility)
    elif in_file_format == "json":
        cbor_str = cddl.from_json(raw, yaml_compat=args.yaml_compatibility)
    elif in_file_format == "cborhex":
        cbor_str = bytes.fromhex(raw.replace("\n", ""))
        cddl.validate_str(cbor_str)
    else:
        cbor_str = raw
        cddl.validate_str(cbor_str)

    return cbor_str
3321
3322
def write_data(args, cddl, cbor_str):
    """Convert `cbor_str` to the requested output format and write it to the output.

    args: Namespace with `output` (file path, or "-" for stdout), `output_as` (format name,
        or None to infer the format from the file extension), `yaml_compatibility`,
        `c_code_var_name`, and `c_code_columns`.
    cddl: The CDDL type used for converting to YAML, JSON, or C code.
    cbor_str: The CBOR encoded data (bytes).

    The formats "yaml"/"yml", "json", "c"/"h"/"c_code", and "cborhex" are written as UTF-8
    text. Everything else is written as binary CBOR.

    Raises AssertionError if C code is requested without args.c_code_var_name.

    The contents are fully converted before the output file is opened, so a failed
    conversion does not truncate an existing file, and the file is always closed again.
    """
    _, out_file_ext = path.splitext(args.output)
    out_file_format = args.output_as or out_file_ext.strip(".")

    binary = False
    if out_file_format in ("yaml", "yml"):
        contents = cddl.str_to_yaml(cbor_str, yaml_compat=args.yaml_compatibility)
    elif out_file_format == "json":
        contents = cddl.str_to_json(cbor_str, yaml_compat=args.yaml_compatibility)
    elif out_file_format in ("c", "h", "c_code"):
        if args.c_code_var_name is None:
            # Raised explicitly (not with an assert statement) so the check also runs under
            # "python -O". The exception type is the one the assert statement raised.
            raise AssertionError("Must specify --c-code-var-name when outputting c code.")
        contents = cddl.str_to_c_code(cbor_str, args.c_code_var_name, args.c_code_columns)
    elif out_file_format == "cborhex":
        # Add newlines every 64 chars
        contents = getrp(r"(.{1,64})").sub(r"\1\n", cbor_str.hex())
    else:
        contents = cbor_str
        binary = True

    if args.output == "-":
        # stdout is not ours to close.
        (sys.stdout.buffer if binary else sys.stdout).write(contents)
    elif binary:
        with open(args.output, "wb") as f:
            f.write(contents)
    else:
        with open(args.output, "w", encoding="utf-8") as f:
            f.write(contents)
3343
3344
def process_validate(args):
    """Handle the "validate" sub-command.

    Parses the CDDL and reads the input data with read_data(), which validates it against
    the entry type. The data itself is discarded.
    """
    entry_type = parse_cddl(args)
    read_data(args, entry_type)
3348
3349
def process_convert(args):
    """Handle the "convert" sub-command.

    Parses the CDDL, reads (and thereby validates) the input data, and writes it to the
    output in the requested format.
    """
    entry_type = parse_cddl(args)
    write_data(args, entry_type, read_data(args, entry_type))
3354
3355
def main():
    """Script entry point: parse the command line and run the chosen sub-command.

    The handler is the `process` function that parse_args() stores in the namespace for
    each sub-command.
    """
    cli_args = parse_args()
    handler = cli_args.process
    handler(cli_args)


if __name__ == "__main__":
    main()
3363