1#!/usr/bin/env python3
2#
3# Copyright (c) 2020 Nordic Semiconductor ASA
4#
5# SPDX-License-Identifier: Apache-2.0
6#
7
8from regex import compile, S, M
9from pprint import pformat, pprint
10from os import path, linesep, makedirs
11from collections import defaultdict, namedtuple
12from collections.abc import Hashable
13from typing import NamedTuple
14from argparse import ArgumentParser, ArgumentTypeError, RawDescriptionHelpFormatter, FileType
15from datetime import datetime
16from copy import copy
17from itertools import tee
18from cbor2 import (loads, dumps, CBORTag, load, CBORDecodeValueError, CBORDecodeEOF, undefined,
19                   CBORSimpleValue)
20from yaml import safe_load as yaml_load, dump as yaml_dump
21from json import loads as json_load, dumps as json_dump
22from io import BytesIO
23from subprocess import Popen, PIPE
24from pathlib import Path, PurePath, PurePosixPath
25from shutil import copyfile
26import sys
27from site import USER_BASE
28from textwrap import wrap, indent
29from importlib.metadata import version
30
# Cache of compiled patterns filled by getrp(). Keys are the pattern string, or a
# (pattern, flags) tuple when flags are given.
regex_cache = {}
# One indentation step, and a newline followed by one indentation step (newl_ind is used by
# ternary_if_chain() when building multi-line expressions).
indentation = "\t"
newl_ind = "\n" + indentation

# Directory containing this file, and the directory one level above it.
SCRIPT_PATH = Path(__file__).absolute().parent
PACKAGE_PATH = Path(__file__).absolute().parents[1]
# Files and directories located relative to the two paths above.
PRELUDE_PATH = SCRIPT_PATH / "prelude.cddl"
VERSION_PATH = SCRIPT_PATH / "VERSION"
C_SRC_PATH = PACKAGE_PATH / "src"
C_INCLUDE_PATH = PACKAGE_PATH / "include"

# The version string is read from the VERSION file when this module is imported, so the
# import fails if that file is missing.
__version__ = VERSION_PATH.read_text(encoding="utf-8").strip()
43
# Largest value of an unsigned integer with the given number of bits.
UINT8_MAX = (1 << 8) - 1
UINT16_MAX = (1 << 16) - 1
UINT32_MAX = (1 << 32) - 1
UINT64_MAX = (1 << 64) - 1

# Largest value of a two's complement signed integer with the given number of bits.
INT8_MAX = (1 << 7) - 1
INT16_MAX = (1 << 15) - 1
INT32_MAX = (1 << 31) - 1
INT64_MAX = (1 << 63) - 1

# Smallest value of a two's complement signed integer with the given number of bits.
INT8_MIN = -(1 << 7)
INT16_MIN = -(1 << 15)
INT32_MIN = -(1 << 31)
INT64_MIN = -(1 << 63)
58
59
def getrp(pattern, flags=0):
    """Return the compiled form of pattern, compiling it only the first time it is requested.

    Compiled patterns are stored in the module-level regex_cache. The cache key is the
    pattern itself when flags is falsy, otherwise the (pattern, flags) tuple.
    """
    cache_key = (pattern, flags) if flags else pattern
    try:
        return regex_cache[cache_key]
    except KeyError:
        compiled = regex_cache[cache_key] = compile(pattern, flags)
        return compiled
65
66
def sizeof(num):
    """Return the size in bytes of the "additional" field if num is encoded as an int.

    Numbers up to 23 need no additional field (0 is returned); larger numbers get the
    smallest of 1, 2, 4 or 8 bytes whose unsigned maximum is not exceeded.

    Raises ValueError if num is larger than UINT64_MAX.
    """
    if num <= 23:
        return 0
    if num <= UINT8_MAX:
        return 1
    if num <= UINT16_MAX:
        return 2
    if num <= UINT32_MAX:
        return 4
    if num <= UINT64_MAX:
        return 8
    raise ValueError("Number too large (more than 64 bits).")
81
82
def verbose_print(verbose_flag, *things):
    """Print things with print(), but only when verbose_flag is truthy."""
    if not verbose_flag:
        return
    print(*things)
87
88
def verbose_pprint(verbose_flag, *things):
    """Pretty print with pprint(), but only when verbose_flag is truthy.

    things are passed on to pprint() as its positional arguments.
    """
    if not verbose_flag:
        return
    pprint(*things)
93
94
global_counter = 0


def counter(reset=False):
    """Retrieve a unique id.

    Each call increments the module-level global_counter and returns the new value. When
    reset is truthy the counter is set back to 0 instead, and 0 is returned.
    """
    global global_counter
    global_counter = 0 if reset else global_counter + 1
    return global_counter
106
107
def list_replace_if_not_null(lst, i, r):
    """Replace element i of a list or tuple with r and return the result. For use in lambdas.

    If lst[i] is the string "NULL", lst itself is returned untouched. A tuple yields a new
    tuple. A list is modified in place and a copy of it is returned.
    """
    if lst[i] == "NULL":
        return lst
    if isinstance(lst, tuple):
        items = list(lst)
        items[i] = r
        return tuple(items)
    assert isinstance(lst, list)
    lst[i] = r
    return list(lst)
121
122
def val_or_null(value, var_name):
    """Return a code snippet that assigns value to var_name and yields a pointer to var_name.

    Returns the string "NULL" if value is None.
    """
    if value is None:
        return "NULL"
    return "(%s = %d, &%s)" % (var_name, value, var_name)
127
128
def tmp_str_or_null(value):
    """Return a code snippet that fills in tmp_str and yields a pointer to it.

    For a string value, tmp_str.value is set to the quoted literal and tmp_str.len to the
    length of that literal (sizeof minus the terminating NUL).

    For None, tmp_str.value is set to NULL and tmp_str.len to 0. (Previously the length was
    computed as sizeof("None") - 1, i.e. a NULL pointer paired with a non-zero length.)
    """
    if value is None:
        return "(tmp_str.value = (uint8_t *)NULL, tmp_str.len = 0, &tmp_str)"
    literal = f'"{value}"'
    return (f"(tmp_str.value = (uint8_t *){literal}, "
            f"tmp_str.len = sizeof({literal}) - 1, &tmp_str)")
134
135
def min_bool_or_null(value):
    """Return a code snippet yielding a pointer to a bool literal holding int(value).

    value is converted with int(), so None is not accepted despite the function's name.
    """
    return "(&(bool){%d})" % int(value)
139
140
def deref_if_not_null(access):
    """Prefix access with "&", unless access is the string "NULL" (returned unchanged)."""
    if access == "NULL":
        return access
    return "&" + access
143
144
def xcode_args(res, *sargs):
    """Return an argument list for a function call to an encoder/decoder function.

    Without extra arguments the list is "state, (res)". With extra arguments it is
    "state, &(res), sargs[0], sargs[1]", so at least two extra arguments are then needed.
    A res of "NULL" is passed through as-is in both forms.
    """
    if len(sargs) == 0:
        plain_res = res if res == "NULL" else "(%s)" % res
        return "state, %s" % (plain_res)
    pointer_res = res if res == "NULL" else "&(%s)" % res
    return "state, %s, %s, %s" % (pointer_res, sargs[0], sargs[1])
153
154
def xcode_statement(func, *sargs, **kwargs):
    """Return the code that calls an encoder/decoder function with a given set of arguments.

    The arguments are formatted by xcode_args(). If func is None, the constant "1" is
    returned instead of a call.
    """
    if func is None:
        return "1"
    call_args = xcode_args(*sargs, **kwargs)
    return "(%s(%s))" % (func, call_args)
161
162
def add_semicolon(decl):
    """Append ";" to the last line of decl unless it already ends with one.

    decl is a list of strings; it is modified in place and also returned. An empty list is
    returned unchanged.
    """
    if len(decl) == 0:
        return decl
    last_line = decl[-1]
    if last_line[-1] != ";":
        decl[-1] = last_line + ";"
    return decl
167
168
def struct_ptr_name(mode):
    """Return "result" when mode is "decode", and "input" for any other mode."""
    if mode == "decode":
        return "result"
    return "input"
171
172
def ternary_if_chain(access, names, xcode_strings):
    """Return a chain of nested ternary expressions selecting on the value of access.

    Link i reads "((access == names[i]) ? xcode_strings[i] : <next link>)", with newl_ind
    inserted before the ":". The innermost alternative is "false". names must contain at
    least one entry, and xcode_strings at least as many entries as names.
    """
    chain = "false"
    # Build from the innermost link outwards. Index 0 is always visited so that empty
    # input fails with an IndexError.
    for idx in reversed(range(max(len(names), 1))):
        chain = "((%s == %s) ? %s%s: %s)" % (
            access,
            names[idx],
            xcode_strings[idx],
            newl_ind,
            chain)
    return chain
180
181
# Values that are written as the name of a C limit macro instead of as a number.
val_conversions = {
    (1 << 64) - 1: "UINT64_MAX",
    (1 << 63) - 1: "INT64_MAX",
    (1 << 32) - 1: "UINT32_MAX",
    (1 << 31) - 1: "INT32_MAX",
    (1 << 16) - 1: "UINT16_MAX",
    (1 << 15) - 1: "INT16_MAX",
    (1 << 8) - 1: "UINT8_MAX",
    (1 << 7) - 1: "INT8_MAX",
    -(1 << 63): "INT64_MIN",
    -(1 << 31): "INT32_MIN",
    -(1 << 15): "INT16_MIN",
    -(1 << 7): "INT8_MIN",
}


def val_to_str(val):
    """Return val as a string.

    Booleans become "true"/"false", values listed in val_conversions become the name
    registered there, and everything else is converted with str().
    """
    if isinstance(val, bool):
        return "true" if val else "false"
    macro_name = val_conversions.get(val) if isinstance(val, Hashable) else None
    return str(val) if macro_name is None else macro_name
204
205
206# Class for parsing CDDL. One instance represents one CBOR data item, with a few caveats:
207#  - For repeated data, one instance represents all repetitions.
208#  - For "OTHER" types, one instance points to another type definition.
209#  - For "GROUP" and "UNION" types, there is no separate data item for the instance.
210class CddlParser:
211    def __init__(self, default_max_qty, my_types, my_control_groups, base_name=None,
212                 short_names=False, base_stem=''):
213        self.id_prefix = "temp_" + str(counter())
214        self.id_num = None  # Unique ID number. Only populated if needed.
215        # The value of the data item. Has different meaning for different
216        # types.
217        self.value = None
218        self.max_value = None  # Maximum value. Only used for numbers and bools.
219        self.min_value = None  # Minimum value. Only used for numbers and bools.
220        # The readable label associated with the element in the CDDL.
221        self.label = None
222        self.min_qty = 1  # The minimum number of times this element is repeated.
223        self.max_qty = 1  # The maximum number of times this element is repeated.
224        # The size of the element. Only used for integers, byte strings, and
225        # text strings.
226        self.size = None
227        self.min_size = None  # Minimum size.
228        self.max_size = None  # Maximum size.
229        # Key element. Only for children of "MAP" elements. self.key is of the
230        # same class as self.
231        self.key = None
232        # The element specified via.cbor or.cborseq(only for byte
233        # strings).self.cbor is of the same class as self.
234        self.cbor = None
235        # Any tags (type 6) to precede the element.
236        self.tags = []
237        # The CDDL string used to determine the min_qty and max_qty. Not used after
238        # min_qty and max_qty are determined.
239        self.quantifier = None
240        # Sockets are types starting with "$" or "$$". Do not fail if these aren't defined.
241        self.is_socket = False
242        # If the type has a ".bits <group_name>", this will contain <group_name> which can be looked
243        # up in my_control_groups.
244        self.bits = None
245        # The "type" of the element. This follows the CBOR types loosely, but are more related to
246        # CDDL concepts. The possible types are "INT", "UINT", "NINT", "FLOAT", "BSTR", "TSTR",
247        # "BOOL", "NIL", "UNDEF", "LIST", "MAP","GROUP", "UNION" and "OTHER". "OTHER" represents a
248        # CDDL type defined with '='.
249        self.type = None
250        self.match_str = ""
251        self.errors = list()
252
253        self.my_types = my_types
254        self.my_control_groups = my_control_groups
255        self.default_max_qty = default_max_qty  # args.default_max_qty
256        self.base_name = base_name  # Used as default for self.get_base_name()
257        # Stem which can be used when generating an id.
258        self.base_stem = base_stem.replace("-", "_")
259        self.short_names = short_names
260
261        if type(self) not in type(self).cddl_regexes:
262            self.cddl_regexes_init()
263
264    @classmethod
265    def from_cddl(cddl_class, cddl_string, default_max_qty, *args, **kwargs):
266        my_types = dict()
267
268        type_strings = cddl_class.get_types(cddl_string)
269        # Separate type_strings as keys in two dicts, one dict for strings that start with &( which
270        # are special control operators for .bits, and one dict for all the regular types.
271        my_types = \
272            {my_type: None for my_type, val in type_strings.items() if not val.startswith("&(")}
273        my_control_groups = \
274            {my_cg: None for my_cg, val in type_strings.items() if val.startswith("&(")}
275
276        # Parse the definitions, replacing the each string with a
277        # CodeGenerator instance.
278        for my_type, cddl_string in type_strings.items():
279            parsed = cddl_class(*args, default_max_qty, my_types, my_control_groups, **kwargs,
280                                base_stem=my_type)
281            parsed.get_value(cddl_string.replace("\n", " ").lstrip("&"))
282            parsed = parsed.flatten()[0]
283            if my_type in my_types:
284                my_types[my_type] = parsed
285            elif my_type in my_control_groups:
286                my_control_groups[my_type] = parsed
287
288        counter(True)
289
290        # post_validate all the definitions.
291        for my_type in my_types:
292            my_types[my_type].set_id_prefix()
293            my_types[my_type].post_validate()
294            my_types[my_type].set_base_names()
295        for my_control_group in my_control_groups:
296            my_control_groups[my_control_group].set_id_prefix()
297            my_control_groups[my_control_group].post_validate_control_group()
298
299        return CddlTypes(my_types, my_control_groups)
300
301    # Strip CDDL comments (';') from the string.
302    @staticmethod
303    def strip_comments(instr):
304        return getrp(r"\;.*?\n").sub('', instr)
305
306    @staticmethod
307    def resolve_backslashes(instr):
308        return getrp(r"\\\n").sub(" ", instr)
309
310    # Returns a dict containing multiple typename=>string
311    @classmethod
312    def get_types(cls, cddl_string):
313        instr = cls.strip_comments(cddl_string)
314        instr = cls.resolve_backslashes(instr)
315        type_regex = \
316            r"(\s*?\$?\$?([\w-]+)\s*(\/{0,2})=\s*(.*?)(?=(\Z|\s*\$?\$?[\w-]+\s*\/{0,2}=(?!\>))))"
317        result = defaultdict(lambda: "")
318        types = [
319            (key, value, slashes)
320            for (_1, key, slashes, value, _2) in getrp(type_regex, S | M).findall(instr)]
321        for key, value, slashes in types:
322            if slashes:
323                result[key] += slashes
324                result[key] += value
325                result[key] = result[key].lstrip(slashes)  # strip from front
326            else:
327                if key in result:
328                    raise ValueError(f"Duplicate CDDL type found: {key}")
329                result[key] = value
330        return dict(result)
331
    # An escaped quotation mark as it appears inside a CDDL text string. It is removed from
    # string values when they are used as part of a name (see generate_base_name()).
    backslash_quotation_mark = r'\"'

    # Generate a (hopefully) unique and descriptive name
    def generate_base_name(self):
        """Return a name for this element that is usable as (part of) a C identifier.

        A raw name is picked from the candidates below, with "-" replaced by "_". If the raw
        name is not a valid identifier, all characters other than ASCII letters, digits and
        "_" are removed, and "_" is prepended if the result still isn't valid.
        """
        # The first truthy entry is used:
        raw_name = ((
            # The label is the default if present:
            self.label
            # Name a key/value pair by its key type or string value:
            or (self.key.value if self.key and self.key.type in ["TSTR", "OTHER"] else None)
            # Name a string by its expected value:
            or (f"{self.value.replace(self.backslash_quotation_mark, '')}_{self.type.lower()}"
                if self.type == "TSTR" and self.value is not None else None)
            # Name an integer by its expected value:
            or (f"{self.type.lower()}{self.value}"
                if self.type in ["INT", "UINT"] and self.value is not None else None)
            # Name a type by its type name
            or (next((key for key, value in self.my_types.items() if value == self), None))
            # Name a control group by its name
            or (next((key for key, value in self.my_control_groups.items() if value == self), None))
            # Name an instance by its type:
            or (self.value + "_m" if self.type == "OTHER" else None)
            # Name a list by its first element:
            or (self.value[0].get_base_name() + "_l"
                if self.type in ["LIST", "GROUP"] and self.value else None)
            # Name a cbor-encoded bstr by its expected cbor contents:
            or ((self.cbor.value + "_bstr")
                if self.cbor and self.cbor.type in ["TSTR", "OTHER"] else None)
            # Name a key value pair by its key (regardless of the key type)
            or ((self.key.generate_base_name() + self.type.lower()) if self.key else None)
            # Name an element by its minimum/maximum "size" (if the min == the max)
            or (f"{self.type.lower()}{self.min_size * 8}"
                if self.min_size and self.min_size == self.max_size else None)
            # Name an element by its minimum/maximum "size" (if the min != the max)
            or (f"{self.type.lower()}{self.min_size * 8}-{self.max_size * 8}"
                if self.min_size and self.max_size else None)
            # Name an element by its type.
            or self.type.lower()).replace("-", "_"))

        # Make the name compatible with C variable names
        # (don't start with a digit, don't use accented letters or symbols other than '_')
        name_regex = getrp(r'[a-zA-Z_][a-zA-Z\d_]*')
        if name_regex.fullmatch(raw_name) is None:
            latinized_name = getrp(r'[^a-zA-Z\d_]').sub("", raw_name)
            if name_regex.fullmatch(latinized_name) is None:
                # Add '_' if name starts with a digit or is empty after removing accented chars.
                latinized_name = "_" + latinized_name
            assert name_regex.fullmatch(latinized_name) is not None, \
                f"Couldn't make '{raw_name}' valid. '{latinized_name}' is invalid."
            return latinized_name
        return raw_name
383
384    # Base name used for functions, variables, and typedefs.
385    def get_base_name(self):
386        if not self.base_name:
387            self.set_base_name(self.generate_base_name())
388        return self.base_name
389
390    # Set an explicit base name for this element.
391    def set_base_name(self, base_name):
392        self.base_name = base_name.replace("-", "_")
393
394    def set_base_names(self):
395        if self.cbor:
396            self.cbor.set_base_name(self.var_name().strip("_") + "_cbor")
397        if self.key:
398            self.key.set_base_name(self.var_name().strip("_") + "_key")
399
400        if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
401            for child in self.value:
402                child.set_base_names()
403        if self.cbor:
404            self.cbor.set_base_names()
405        if self.key:
406            self.key.set_base_names()
407
408    # Add uniqueness to the base name.
409    def id(self, with_prefix=True):
410        raw_name = self.get_base_name()
411        if not with_prefix and self.short_names:
412            return raw_name
413        if (self.id_prefix
414                and (f"{self.id_prefix}_" not in raw_name)
415                and (self.id_prefix != raw_name.strip("_"))):
416            return f"{self.id_prefix}_{raw_name}"
417        if (self.base_stem
418                and (f"{self.base_stem}_" not in raw_name)
419                and (self.base_stem != raw_name.strip("_"))):
420            return f"{self.base_stem}_{raw_name}"
421        return raw_name
422
    def init_args(self):
        """Return the positional arguments for re-running __init__() on this element."""
        return (self.default_max_qty,)
425
426    def init_kwargs(self):
427        return {
428            "my_types": self.my_types, "my_control_groups": self.my_control_groups,
429            "short_names": self.short_names}
430
431    def set_id_prefix(self, id_prefix=''):
432        self.id_prefix = id_prefix
433        if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
434            for child in self.value:
435                if child.single_func_impl_condition():
436                    child.set_id_prefix(self.generate_base_name())
437                else:
438                    child.set_id_prefix(self.child_base_id())
439        if self.cbor:
440            self.cbor.set_id_prefix(self.child_base_id())
441        if self.key:
442            self.key.set_id_prefix(self.child_base_id())
443
    def child_base_id(self):
        """Return the id to pass to children for them to use as basis for their id/base name.

        In this class it is simply self.id().
        """
        return self.id()
447
448    def get_id_num(self):
449        if self.id_num is None:
450            self.id_num = counter()
451        return self.id_num
452
453    # Human readable representation.
454    def mrepr(self, newline):
455        reprstr = ''
456        if self.quantifier:
457            reprstr += self.quantifier
458        if self.label:
459            reprstr += self.label + ':'
460        for tag in self.tags:
461            reprstr += f"#6.{tag}"
462        if self.key:
463            reprstr += repr(self.key) + " => "
464        if self.is_unambiguous():
465            reprstr += '/'
466        if self.is_unambiguous_repeated():
467            reprstr += '/'
468        reprstr += self.type
469        if self.size:
470            reprstr += '(%d)' % self.size
471        if newline:
472            reprstr += '\n'
473        if self.value:
474            reprstr += pformat(self.value, indent=4, width=1)
475        if self.cbor:
476            reprstr += " cbor: " + repr(self.cbor)
477        return reprstr.replace('\n', '\n    ')
478
479    def _flatten(self):
480        new_value = []
481        if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
482            for child in self.value:
483                new_value.extend(
484                    child.flatten(allow_multi=self.type != "UNION"))
485            self.value = new_value
486        if self.key:
487            self.key = self.key.flatten()[0]
488        if self.cbor:
489            self.cbor = self.cbor.flatten()[0]
490
491    def flatten(self, allow_multi=False):
492        self._flatten()
493        if self.type == "OTHER" and self.is_socket and self.value not in self.my_types:
494            return []
495        if self.type in ["GROUP", "UNION"]\
496                and (len(self.value) == 1)\
497                and (not (self.key and self.value[0].key)):
498            self.value[0].min_qty *= self.min_qty
499            self.value[0].max_qty *= self.max_qty
500            if not self.value[0].label:
501                self.value[0].label = self.label
502            if not self.value[0].key:
503                self.value[0].key = self.key
504            self.value[0].tags.extend(self.tags)
505            return self.value
506        elif allow_multi and self.type in ["GROUP"] and self.min_qty == 1 and self.max_qty == 1:
507            return self.value
508        else:
509            return [self]
510
    def set_min_value(self, min_value):
        """Set self.min_value."""
        self.min_value = min_value
513
    def set_max_value(self, max_value):
        """Set self.max_value."""
        self.max_value = max_value
516
517    # Set the self.type and self.value of this element. For use during CDDL
518    # parsing.
519    def type_and_value(self, new_type, value_generator):
520        if self.type is not None:
521            raise TypeError(
522                "Cannot have two values: %s, %s" %
523                (self.type, new_type))
524        if new_type is None:
525            raise TypeError("Cannot set None as type")
526        if new_type == "UNION" and self.value is not None:
527            raise ValueError("Did not expect multiple parsed values for union")
528
529        self.type = new_type
530        self.set_value(value_generator)
531
532    def set_value(self, value_generator):
533        value = value_generator()
534        self.value = value
535
536        if self.type == "OTHER" and self.value.startswith("$"):
537            self.value = self.value.lstrip("$")
538            self.is_socket = True
539
540        if self.type in ["BSTR", "TSTR"]:
541            if value is not None:
542                self.set_size(len(value))
543        if self.type in ["UINT", "NINT"]:
544            if value is not None:
545                self.size = sizeof(value)
546                self.set_min_value(value)
547                self.set_max_value(value)
548        if self.type == "NINT":
549            self.max_value = -1
550
551    # Set the self.type and self.minValue and self.max_value (or self.min_size and self.max_size
552    # depending on the type) of this element. For use during CDDL parsing.
553    def type_and_range(self, new_type, min_val, max_val, inc_end=True):
554        if not inc_end:
555            max_val -= 1
556        if new_type not in ["INT", "UINT", "NINT"]:
557            raise TypeError(
558                "Only integers (not %s) can have range" %
559                (new_type,))
560        if min_val > max_val:
561            raise TypeError(
562                "Range has larger minimum than maximum (min %d, max %d)" %
563                (min_val, max_val))
564        if min_val == max_val:
565            return self.type_and_value(new_type, min_val)
566        self.type = new_type
567        self.set_min_value(min_val)
568        self.set_max_value(max_val)
569        if new_type in "UINT":
570            self.set_size_range(sizeof(min_val), sizeof(max_val))
571        if new_type == "NINT":
572            self.set_size_range(sizeof(abs(max_val)), sizeof(abs(min_val)))
573        if new_type == "INT":
574            self.set_size_range(None, max(sizeof(abs(max_val)), sizeof(abs(min_val))))
575
    def type_value_size(self, new_type, value, size):
        """Set the self.type, self.value and self.size of this element. For use during CDDL
        parsing.

        value is passed on to type_and_value(), which calls it to obtain the actual value,
        so it must be a callable taking no arguments.
        """
        self.type_and_value(new_type, value)
        self.set_size(size)
581
    def type_value_size_range(self, new_type, value, min_size, max_size):
        """Set the self.type, self.value, self.min_size and self.max_size of this element.
        For use during CDDL parsing.

        value is passed on to type_and_value(), which calls it to obtain the actual value,
        so it must be a callable taking no arguments.
        """
        self.type_and_value(new_type, value)
        self.set_size_range(min_size, max_size)
587
588    # Set the self.label of this element. For use during CDDL parsing.
589    def set_label(self, label):
590        if self.type is not None:
591            raise TypeError("Cannot have label after value: " + label)
592        self.label = label
593
594    # Set the self.quantifier, self.min_qty, and self.max_qty of this element. For
595    # use during CDDL parsing.
596    def set_quantifier(self, quantifier):
597        if self.type is not None:
598            raise TypeError(
599                "Cannot have quantifier after value: " + quantifier)
600
601        quantifier_mapping = [
602            (r"\?", lambda mo: (0, 1)),
603            (r"\*", lambda mo: (0, None)),
604            (r"\+", lambda mo: (1, None)),
605            (r"(.*?)\*\*?(.*)",
606                lambda mo: (int(mo.groups()[0] or "0", 0), int(mo.groups()[1] or "0", 0) or None)),
607        ]
608
609        self.quantifier = quantifier
610        for (reg, handler) in quantifier_mapping:
611            match_obj = getrp(reg).match(quantifier)
612            if match_obj:
613                (self.min_qty, self.max_qty) = handler(match_obj)
614                if self.max_qty is None:
615                    self.max_qty = self.default_max_qty
616                return
617        raise ValueError("invalid quantifier: %s" % quantifier)
618
619    # Set the self.size of this element. This will also set the self.minValue and self.max_value of
620    # UINT types.
621    # For use during CDDL parsing.
622    def set_size(self, size):
623        if self.type is None:
624            raise TypeError("Cannot have size before value: " + str(size))
625        elif self.type in ["INT", "UINT", "NINT"]:
626            value = 256**size
627            if self.type == "INT":
628                self.max_value = int((value >> 1) - 1)
629            if self.type == "UINT":
630                self.max_value = int(value - 1)
631            if self.type in ["INT", "NINT"]:
632                self.min_value = int(-1 * (value >> 1))
633        elif self.type in ["BSTR", "TSTR", "FLOAT"]:
634            self.set_size_range(size, size)
635        else:
636            raise TypeError(".size cannot be applied to %s" % self.type)
637
638    # Set the self.minValue and self.max_value or self.min_size and self.max_size of this element
639    # based on what values can be contained within an integer of a certain size. For use during
640    # CDDL parsing.
641    def set_size_range(self, min_size, max_size_in, inc_end=True):
642        max_size = max_size_in if inc_end else max_size_in - 1
643
644        if (min_size and min_size < 0 or max_size and max_size < 0) \
645           or (None not in [min_size, max_size] and min_size > max_size):
646            raise TypeError(
647                "Invalid size range (min %d, max %d)" %
648                (min_size, max_size))
649
650        self.set_min_size(min_size)
651        self.set_max_size(max_size)
652
    def set_min_size(self, min_size):
        """Set self.min_size. A falsy min_size (None or 0) is stored as None.

        For "UINT" elements, an attribute named minValue is written as well.
        """
        if self.type == "UINT":
            # NOTE(review): this writes self.minValue, not the self.min_value attribute that
            # __init__() creates and set_min_value() sets, so min_value is left untouched.
            # Also, min(0, abs(...)) is always 0, so the stored number is always 256**0 == 1.
            # Presumably self.min_value and max(...) were intended -- confirm before changing,
            # since type_and_range() calls this after it has set min_value.
            self.minValue = 256**min(0, abs(min_size - 1)) if min_size else None
        self.min_size = min_size if min_size else None
658
659    # Set self.max_size, and self.max_value if type is UINT.
660    def set_max_size(self, max_size):
661        if self.type == "UINT" and max_size and self.max_value is None:
662            if max_size > 8:
663                raise TypeError(
664                    "Size too large for integer. size %d" %
665                    max_size)
666            self.max_value = 256**max_size - 1
667        self.max_size = max_size
668
669    # Set the self.cbor of this element. For use during CDDL parsing.
670    def set_cbor(self, cbor, cborseq):
671        if self.type != "BSTR":
672            raise TypeError(
673                "%s must be used with bstr." %
674                (".cborseq" if cborseq else ".cbor",))
675        self.cbor = cbor
676        if cborseq:
677            self.cbor.max_qty = self.default_max_qty
678
679    def set_bits(self, bits):
680        if self.type != "UINT":
681            raise TypeError(".bits must be used with bstr.")
682        self.bits = bits
683
684    # Set the self.key of this element. For use during CDDL parsing.
685    def set_key(self, key):
686        if self.key is not None:
687            raise TypeError("Cannot have two keys: " + key)
688        if key.type == "GROUP":
689            raise TypeError("A key cannot be a group because it might represent more than 1 type.")
690        self.key = key
691
692    # Set the self.label OR self.key of this element. In the CDDL "foo: bar", foo can be either a
693    # label or a key depending on whether it is in a map. This code uses a slightly different method
694    # for choosing between label and key. If the string is recognized as a type, it is treated as a
695    # key. For use during CDDL parsing.
696    def set_key_or_label(self, key_or_label):
697        if key_or_label in self.my_types:
698            self.set_key(self.parse(key_or_label)[0])
699            assert self.key.type == "OTHER", "This should only be able to produce an OTHER key."
700            if self.label is None:
701                self.set_label(key_or_label)
702        else:
703            self.set_label(key_or_label)
704
705    def add_tag(self, tag):
706        self.tags.append(int(tag))
707
708    # Append to the self.value of this element. Used with the "UNION" type, which has
709    # a python list as self.value. The list represents the "children" of the
710    # type. For use during CDDL parsing.
711    def union_add_value(self, value, doubleslash=False):
712        if self.type != "UNION":
713            convert_val = copy(self)
714            self.__init__(*self.init_args(), **self.init_kwargs())
715            self.type_and_value("UNION", lambda: [convert_val])
716
717            self.base_name = convert_val.base_name
718            convert_val.base_name = None
719            self.base_stem = convert_val.base_stem
720
721            if not doubleslash:
722                self.label = convert_val.label
723                self.key = convert_val.key
724                self.quantifier = convert_val.quantifier
725                self.max_qty = convert_val.max_qty
726                self.min_qty = convert_val.min_qty
727
728                convert_val.label = None
729                convert_val.key = None
730                convert_val.quantifier = None
731                convert_val.max_qty = 1
732                convert_val.min_qty = 1
733        self.value.append(value)
734
735    def convert_to_key(self):
736        convert_val = copy(self)
737        self.__init__(*self.init_args(), **self.init_kwargs())
738        self.set_key(convert_val)
739
740        self.label = convert_val.label
741        self.quantifier = convert_val.quantifier
742        self.max_qty = convert_val.max_qty
743        self.min_qty = convert_val.min_qty
744        self.base_name = convert_val.base_name
745        self.base_stem = convert_val.base_stem
746
747        convert_val.label = None
748        convert_val.quantifier = None
749        convert_val.max_qty = 1
750        convert_val.min_qty = 1
751        convert_val.base_name = None
752
    # A dict with lists of regexes and their corresponding handlers, keyed by parser class.
    # Each class gets its own entry (added by cddl_regexes_init()) in case multiple inheritors
    # of CddlParser are used at once, in which case they may have slightly different handlers.
    cddl_regexes = dict()
757
    def cddl_regexes_init(self):
        """Build the ordered (regex, handler) table for type(self) and store it in cddl_regexes.

        get_value() tries the regexes in list order against the start of the remaining input and
        calls the first matching entry's handler as handler(element, matched_text). matched_text
        is the contents of the named group "item" if the regex has one, otherwise the whole match.
        """
        # Building blocks for integer literals: hex, octal, binary or decimal.
        match_uint = r"(0x[0-9a-fA-F]+|0o[0-7]+|0b[01]+|\d+)"
        match_int = r"(-?" + match_uint + ")"
        match_nint = r"(-" + match_uint + ")"

        self_type = type(self)

        # The "range_types" match the contents of brackets i.e. (), [], and {},
        # and strings, i.e. ' or "
        # (?R) recurses into the whole pattern so that nested brackets are matched as a unit.
        range_types = [
            (r'\[(?P<item>(?>[^[\]]+|(?R))*)\]',
             lambda m_self, list_str: m_self.type_and_value(
                 "LIST", lambda: m_self.parse(list_str))),
            (r'\((?P<item>(?>[^\(\)]+|(?R))*)\)',
             lambda m_self, group_str: m_self.type_and_value(
                 "GROUP", lambda: m_self.parse(group_str))),
            (r'{(?P<item>(?>[^{}]+|(?R))*)}',
             lambda m_self, map_str: m_self.type_and_value(
                 "MAP", lambda: m_self.parse(map_str))),
            (r'\'(?P<item>.*?)(?<!\\)\'',
             lambda m_self, string: m_self.type_and_value("BSTR", lambda: string)),
            (r'\"(?P<item>.*?)(?<!\\)\"',
             lambda m_self, string: m_self.type_and_value("TSTR", lambda: string)),
        ]
        # Join the range_types into a single alternation for reuse in the single-slash union regex
        # below. Every "item" group name is rewritten to a unique name (it0em, it1em, ...),
        # presumably so that the outer regex's own "item" group is the only one with that name.
        range_types_regex = '|'.join([regex for (regex, _) in range_types])
        for i in range(range_types_regex.count("item")):
            range_types_regex = range_types_regex.replace("item", "it%dem" % i, 1)

        # The following regexes match different parts of the element. The order of the list is
        # important because it implements the operator precedence defined in the CDDL spec.
        # The range_types are separate because they are reused in one of the other regexes.
        self_type.cddl_regexes[self_type] = range_types + [
            # Group choice ("//"). A choice containing a comma is wrapped in () before parsing.
            (r'\/\/\s*(?P<item>.+?)(?=\/\/|\Z)',
             lambda m_self, union_str: m_self.union_add_value(
                 m_self.parse("(%s)" % union_str if ',' in union_str else union_str)[0],
                 doubleslash=True)),
            # A bareword followed by ":" is a key or a label.
            (r'(?P<item>[^\W\d][\w-]*)\s*:',
             self_type.set_key_or_label),
            # "=>" or ":" after an element turns what has been parsed so far into the key.
            (r'((\=\>)|:)',
             lambda m_self, _: m_self.convert_to_key()),
            # Quantifiers: "+", "*", "?" and the "n*m" forms.
            (r'([+*?])',
             self_type.set_quantifier),
            (r'(' + match_uint + r'\*\*?' + match_uint + r'?)',
             self_type.set_quantifier),
            # Type choice ("/").
            (r'\/\s*(?P<item>((' + range_types_regex + r')|[^,\[\]{}()])+?)(?=\/|\Z|,)',
             lambda m_self, union_str: m_self.union_add_value(
                 m_self.parse(union_str)[0])),
            # Built-in type names. The lookahead prevents matching a prefix of a longer name.
            (r'(uint|nint|int|float|bstr|tstr|bool|nil|any)(?![\w-])',
             lambda m_self, type_str: m_self.type_and_value(type_str.upper(), lambda: None)),
            (r'undefined(?!\w)',
             lambda m_self, _: m_self.type_and_value("UNDEF", lambda: None)),
            # Floats with an explicit size (in bytes) or size range.
            (r'float16(?![\w-])',
             lambda m_self, _: m_self.type_value_size("FLOAT", lambda: None, 2)),
            (r'float16-32(?![\w-])',
             lambda m_self, _: m_self.type_value_size_range("FLOAT", lambda: None, 2, 4)),
            (r'float32(?![\w-])',
             lambda m_self, _: m_self.type_value_size("FLOAT", lambda: None, 4)),
            (r'float32-64(?![\w-])',
             lambda m_self, _: m_self.type_value_size_range("FLOAT", lambda: None, 4, 8)),
            (r'float64(?![\w-])',
             lambda m_self, _: m_self.type_value_size("FLOAT", lambda: None, 8)),
            # Float literal.
            (r'\-?\d*\.\d+',
             lambda m_self, num: m_self.type_and_value("FLOAT", lambda: float(num))),
            # Integer ranges written with "..". The resulting type depends on the signs of the
            # two limits.
            (match_uint + r'\.\.' + match_uint,
             lambda m_self, _range: m_self.type_and_range(
                 "UINT", *map(lambda num: int(num, 0), _range.split("..")))),
            (match_nint + r'\.\.' + match_uint,
             lambda m_self, _range: m_self.type_and_range(
                 "INT", *map(lambda num: int(num, 0), _range.split("..")))),
            (match_nint + r'\.\.' + match_nint,
             lambda m_self, _range: m_self.type_and_range(
                 "NINT", *map(lambda num: int(num, 0), _range.split("..")))),
            # Integer ranges written with "...", which are passed on with inc_end=False.
            (match_uint + r'\.\.\.' + match_uint,
             lambda m_self, _range: m_self.type_and_range(
                 "UINT", *map(lambda num: int(num, 0), _range.split("...")), inc_end=False)),
            (match_nint + r'\.\.\.' + match_uint,
             lambda m_self, _range: m_self.type_and_range(
                 "INT", *map(lambda num: int(num, 0), _range.split("...")), inc_end=False)),
            (match_nint + r'\.\.\.' + match_nint,
             lambda m_self, _range: m_self.type_and_range(
                 "NINT", *map(lambda num: int(num, 0), _range.split("...")), inc_end=False)),
            # Integer literals. The base is deduced from the prefix by int(num, 0).
            (match_nint,
             lambda m_self, num: m_self.type_and_value("NINT", lambda: int(num, 0))),
            (match_uint,
             lambda m_self, num: m_self.type_and_value("UINT", lambda: int(num, 0))),
            # Boolean literals.
            (r'true(?!\w)',
             lambda m_self, _: m_self.type_and_value("BOOL", lambda: True)),
            (r'false(?!\w)',
             lambda m_self, _: m_self.type_and_value("BOOL", lambda: False)),
            # Tag, e.g. #6.24
            (r'#6\.(?P<item>\d+)',
             self_type.add_tag),
            # Any other name is a reference to another type ("OTHER").
            (r'(\$?\$?[\w-]+)',
             lambda m_self, other_str: m_self.type_and_value("OTHER", lambda: other_str)),
            # Control operators. The operand may optionally be enclosed in parentheses.
            (r'\.size \(?(?P<item>' + match_int + r'\.\.' + match_int + r')\)?',
             lambda m_self, _range: m_self.set_size_range(
                 *map(lambda num: int(num, 0), _range.split("..")))),
            (r'\.size \(?(?P<item>' + match_int + r'\.\.\.' + match_int + r')\)?',
             lambda m_self, _range: m_self.set_size_range(
                 *map(lambda num: int(num, 0), _range.split("...")), inc_end=False)),
            (r'\.size \(?(?P<item>' + match_uint + r')\)?',
             lambda m_self, size: m_self.set_size(int(size, 0))),
            # .gt and .lt are exclusive, so the limit is adjusted by one.
            (r'\.gt \(?(?P<item>' + match_int + r')\)?',
             lambda m_self, minvalue: m_self.set_min_value(int(minvalue, 0) + 1)),
            (r'\.lt \(?(?P<item>' + match_int + r')\)?',
             lambda m_self, maxvalue: m_self.set_max_value(int(maxvalue, 0) - 1)),
            (r'\.ge \(?(?P<item>' + match_int + r')\)?',
             lambda m_self, minvalue: m_self.set_min_value(int(minvalue, 0))),
            (r'\.le \(?(?P<item>' + match_int + r')\)?',
             lambda m_self, maxvalue: m_self.set_max_value(int(maxvalue, 0))),
            (r'\.eq \(?(?P<item>' + match_int + r')\)?',
             lambda m_self, value: m_self.set_value(lambda: int(value, 0))),
            (r'\.eq \"(?P<item>.*?)(?<!\\)\"',
             lambda m_self, value: m_self.set_value(lambda: value)),
            # .cbor/.cborseq with a parenthesized or a bare operand. The last argument to
            # set_cbor() is False for .cbor and True for .cborseq.
            (r'\.cbor (\((?P<item>(?>[^\(\)]+|(?1))*)\))',
             lambda m_self, type_str: m_self.set_cbor(m_self.parse(type_str)[0], False)),
            (r'\.cbor (?P<item>[^\s,]+)',
             lambda m_self, type_str: m_self.set_cbor(m_self.parse(type_str)[0], False)),
            (r'\.cborseq (\((?P<item>(?>[^\(\)]+|(?1))*)\))',
             lambda m_self, type_str: m_self.set_cbor(m_self.parse(type_str)[0], True)),
            (r'\.cborseq (?P<item>[^\s,]+)',
             lambda m_self, type_str: m_self.set_cbor(m_self.parse(type_str)[0], True)),
            (r'\.bits (?P<item>[\w-]+)',
             lambda m_self, bits_str: m_self.set_bits(bits_str))
        ]
882
883    # Parse from the beginning of instr (string) until a full element has been parsed. self will
884    # become that element. This function is recursive, so if a nested element
885    # ("MAP"/"LIST"/"UNION"/"GROUP") is encountered, this function will create new instances and add
886    # them to self.value as a list. Likewise, if a key or cbor definition is encountered, a new
887    # element will be created and assigned to self.key or self.cbor. When new elements are created,
888    # get_value() is called on those elements, via parse().
889    def get_value(self, instr):
890        types = type(self).cddl_regexes[type(self)]
891
892        # Keep parsing until a comma, or to the end of the string.
893        while instr != '' and instr[0] != ',':
894            match_obj = None
895            for (reg, handler) in types:
896                match_obj = getrp(reg).match(instr)
897                if match_obj:
898                    try:
899                        match_str = match_obj.group("item")
900                    except IndexError:
901                        match_str = match_obj.group(0)
902                    try:
903                        handler(self, match_str)
904                    except Exception as e:
905                        raise Exception("Failed while parsing this: '%s'" % match_str) from e
906                    self.match_str += match_str
907                    old_len = len(instr)
908                    instr = getrp(reg).sub('', instr, count=1).lstrip()
909                    if old_len == len(instr):
910                        raise Exception("empty match")
911                    break
912
913            if not match_obj:
914                raise TypeError("Could not parse this: '%s'" % instr)
915
916        instr = instr[1:]
917        if not self.type:
918            raise ValueError("No proper value while parsing: %s" % instr)
919
920        # Return the unparsed part of the string.
921        return instr.strip()
922
923    # For checking whether this element has a key (i.e. to check that it is a valid "MAP" child).
924    # This must have some recursion since CDDL allows the key to be hidden
925    # behind layers of indirection.
926    def elem_has_key(self):
927        return self.key is not None\
928            or (self.type == "OTHER" and self.my_types[self.value].elem_has_key())\
929            or (self.type in ["GROUP", "UNION"]
930                and all(child.elem_has_key() for child in self.value))
931
932    # Function for performing validations that must be done after all parsing is complete. This is
933    # recursive, so it will post_validate all its children + key + cbor.
934    def post_validate(self):
935        # Validation of this element.
936        if self.type in ["LIST", "MAP"]:
937            none_keys = [child for child in self.value if not child.elem_has_key()]
938            child_keys = [child for child in self.value if child not in none_keys]
939            if self.type == "MAP" and none_keys:
940                raise TypeError(
941                    "Map member(s) must have key: " + str(none_keys) + " pointing to "
942                    + str(
943                        [self.my_types[elem.value] for elem in none_keys
944                            if elem.type == "OTHER"]))
945            if self.type == "LIST" and child_keys:
946                raise TypeError(
947                    str(self) + linesep
948                    + "List member(s) cannot have key: " + str(child_keys) + " pointing to "
949                    + str(
950                        [self.my_types[elem.value] for elem in child_keys
951                            if elem.type == "OTHER"]))
952        if self.type == "OTHER":
953            if self.value not in self.my_types.keys() or not isinstance(
954                    self.my_types[self.value], type(self)):
955                raise TypeError("%s has not been parsed." % self.value)
956        if self.type == "LIST":
957            for child in self.value[:-1]:
958                if child.type == "ANY":
959                    if child.min_qty != child.max_qty:
960                        raise TypeError(f"ambiguous quantity of 'any' is not supported in list, "
961                                        + "except as last element:\n{str(child)}")
962        if self.type == "UNION" and len(self.value) > 1:
963            if any(((not child.key and child.type == "ANY") or (
964                    child.key and child.key.type == "ANY")) for child in self.value):
965                raise TypeError(
966                    "'any' inside union is not supported since it would always be triggered.")
967
968        # Validation of child elements.
969        if self.type in ["MAP", "LIST", "UNION", "GROUP"]:
970            for child in self.value:
971                child.post_validate()
972        if self.key:
973            self.key.post_validate()
974        if self.cbor:
975            self.cbor.post_validate()
976
977    def post_validate_control_group(self):
978        if self.type != "GROUP":
979            raise TypeError("control groups must be of GROUP type.")
980        for c in self.value:
981            if c.type != "UINT" or c.value is None or c.value < 0:
982                raise TypeError("control group members must be literal positive integers.")
983
984    # Parses entire instr and returns a list of instances.
985    def parse(self, instr):
986        instr = instr.strip()
987        values = []
988        while instr != '':
989            value = type(self)(*self.init_args(), **self.init_kwargs(), base_stem=self.base_stem)
990            instr = value.get_value(instr)
991            values.append(value)
992        return values
993
994    def __repr__(self):
995        return self.mrepr(False)
996
997
class CddlXcoder(CddlParser):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # The struct hierarchy leading up to this element, i.e. the prefix used by C code that
        # accesses this element.
        self.accessPrefix = None
        # Used as a guard against endless recursion in self.dependsOn()
        self.dependsOnCall = False
        self.skipped = False

    # Name of variables and enum members for this element. Names that would collide with the
    # "int", "bool" and "float" type names are capitalized.
    def var_name(self, with_prefix=False):
        name = self.id(with_prefix=with_prefix)
        return name.capitalize() if name in ["int", "bool", "float"] else name

    # Whether this element is skipped. When it is, val_access() returns var_access() rather than
    # a member named after this element.
    def skip_condition(self):
        if self.skipped:
            return True
        if self.type == "OTHER":
            return (not (self.repeated_multi_var_condition() or self.multi_var_condition())
                    and (self.single_func_impl_condition() or self in self.my_types.values()))
        if self.type in ["LIST", "MAP", "GROUP"]:
            return not self.multi_val_condition()
        return False

    # Set self.skipped. The provided value is overridden with True for keyless elements that need
    # both a range check and their own repeated function.
    def set_skipped(self, skipped):
        forced = (self.range_check_condition()
                  and self.repeated_single_func_impl_condition()
                  and not self.key)
        self.skipped = True if forced else skipped

    # Recursively set the access prefix for this element and all its children.
    def set_access_prefix(self, prefix):
        self.accessPrefix = prefix
        if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
            # All children get their prefix before any of them gets its skipped state.
            for child in self.value:
                child.set_access_prefix(self.var_access())
            for child in self.value:
                child.set_skipped(self.skip_condition())
        elif self in self.my_types.values() and self.type != "OTHER":
            self.set_skipped(not self.multi_member())
        if self.key is not None:
            self.key.set_access_prefix(self.var_access())
        if self.cbor_var_condition():
            self.cbor.set_access_prefix(self.var_access())

    # Whether this type has multiple member variables.
    def multi_member(self):
        return self.multi_var_condition() or self.repeated_multi_var_condition()

    # Whether the value of this element is fully determined by the CDDL, i.e. it is nil,
    # undefined or any, a literal with a specific value, or a reference to an unambiguous type.
    def is_unambiguous_value(self):
        if self.type in ["NIL", "UNDEF", "ANY"]:
            return True
        if (self.type in ["INT", "NINT", "UINT", "FLOAT", "BSTR", "TSTR", "BOOL"]
                and self.value is not None):
            return True
        return self.type == "OTHER" and self.my_types[self.value].is_unambiguous()

    # Whether a single repetition of this element is unambiguous: either the value (and key, if
    # any) is unambiguous, or this is a LIST/GROUP/MAP whose children (if any) are all
    # unambiguous. Note that the key is only considered in the first case.
    def is_unambiguous_repeated(self):
        if (self.is_unambiguous_value()
                and (self.key is None or self.key.is_unambiguous_repeated())):
            return True
        return (self.type in ["LIST", "GROUP", "MAP"]
                and (len(self.value) == 0
                     or all(child.is_unambiguous() for child in self.value)))

    # Whether this element is unambiguous, including how many times it is repeated.
    def is_unambiguous(self):
        return self.is_unambiguous_repeated() and (self.min_qty == self.max_qty)

    # Create an access prefix based on an existing prefix, delimiter and a
    # suffix.
    def access_append_delimiter(self, prefix, delimiter, *suffix):
        assert prefix is not None, "No access prefix for %s" % self.var_name()
        return delimiter.join((prefix, *suffix))

    # Create an access prefix from this element's prefix, delimiter and a
    # provided suffix.
    def access_append(self, *suffix):
        return self.access_append_delimiter(self.accessPrefix, '.', *suffix)

    # "Path" to this element's variable, or "NULL" if the element is unambiguous.
    def var_access(self):
        return "NULL" if self.is_unambiguous() else self.access_append()

    # "Path" to access this element's actual value variable.
    def val_access(self):
        if self.is_unambiguous_repeated():
            return "NULL"
        if self.skip_condition():
            # A skipped element has no member of its own, so use the path to the element itself.
            return self.var_access()
        return self.access_append(self.var_name())

    # Like val_access(), but without taking skip_condition() into account.
    def repeated_val_access(self):
        if self.is_unambiguous_repeated():
            return "NULL"
        return self.access_append(self.var_name())

    # Whether to include a "present" variable for this element.
    def present_var_condition(self):
        return self.min_qty == 0 and isinstance(self.max_qty, int) and self.max_qty <= 1

    # Whether to include a "count" variable for this element.
    def count_var_condition(self):
        return isinstance(self.max_qty, str) or self.max_qty > 1

    # Whether this element counts as CBOR contents for cbor_var_condition(): anything but
    # nil/undefined/any, with type references resolved through my_types.
    def is_cbor(self):
        if self.type in ["NIL", "UNDEF", "ANY"]:
            return False
        if self.type == "OTHER":
            return self.my_types[self.value].is_cbor()
        return True

    # Whether to include a "cbor" variable for this element.
    def cbor_var_condition(self):
        return (self.cbor is not None) and self.cbor.is_cbor()

    # Whether to include a "choice" variable for this element.
    def choice_var_condition(self):
        return self.type == "UNION"

    # Whether this specific element has a key of its own.
    def reduced_key_var_condition(self):
        return self.key is not None

    # Whether to include a "key" variable for this element.
    def key_var_condition(self):
        if self.reduced_key_var_condition():
            return True
        if self.type == "OTHER":
            return self.my_types[self.value].key_var_condition()
        if self.type in ["GROUP", "UNION"]:
            # Only the first child is considered.
            return len(self.value) >= 1 and self.value[0].reduced_key_var_condition()
        return False

    # Whether this value adds any repeated elements by itself. I.e. excluding
    # multiple elements from children.
    def self_repeated_multi_var_condition(self):
        return (self.key_var_condition()
                or self.cbor_var_condition()
                or self.choice_var_condition())

    # Whether this element's actual value has multiple members.
    def multi_val_condition(self):
        if self.type not in ["LIST", "MAP", "GROUP", "UNION"]:
            return False
        if len(self.value) > 1:
            return True
        return len(self.value) == 1 and self.value[0].multi_member()

    # Whether any extra variables are to be included for this element for each
    # repetition.
    def repeated_multi_var_condition(self):
        return self.self_repeated_multi_var_condition() or self.multi_val_condition()

    # Whether any extra variables are to be included for this element outside
    # of repetitions.
    # Also, whether this element must involve a call to multi_xcode(), i.e. unless
    # it's repeated exactly once.
    def multi_var_condition(self):
        return self.present_var_condition() or self.count_var_condition()

    # Whether this element needs a range check: an integer without a fixed value but with a
    # min/max value (or, for UINT, a "bits" restriction), or a string without a fixed value but
    # with a min/max size. Type references are resolved through my_types.
    def range_check_condition(self):
        if self.type == "OTHER":
            return self.my_types[self.value].range_check_condition()
        if self.type not in ["INT", "NINT", "UINT", "BSTR", "TSTR"] or self.value is not None:
            return False
        if self.type in ["BSTR", "TSTR"]:
            return self.min_size is not None or self.max_size is not None
        # Only the integer types remain.
        if self.min_value is not None or self.max_value is not None:
            return True
        return self.type == "UINT" and bool(self.bits)

    # Whether this element should have a typedef in the code.
    def type_def_condition(self):
        return bool(
            self in self.my_types.values()
            and self.multi_member()
            and not self.is_unambiguous())

    # Whether this type needs a typedef for its repeated part.
    def repeated_type_def_condition(self):
        return (
            self.repeated_multi_var_condition()
            and self.multi_var_condition()
            and not self.is_unambiguous_repeated())

    # Whether this element needs its own encoder/decoder function.
    def single_func_impl_condition(self):
        return (
            self.reduced_key_var_condition()
            or self.cbor_var_condition()
            or (self.tags and self in self.my_types.values())
            or self.type_def_condition()
            or (self.type in ["LIST", "MAP", "GROUP"] and len(self.value) != 0))

    # Whether the repeated part of this element needs its own encoder/decoder function.
    def repeated_single_func_impl_condition(self):
        return (
            self.repeated_type_def_condition()
            or (self.type in ["LIST", "MAP", "GROUP"] and self.multi_member())
            or (self.multi_var_condition()
                and (self.self_repeated_multi_var_condition()
                     or self.range_check_condition())))

    # The integer value that identifies this element, or None if there is no such value.
    # The key takes precedence if there is one. Groups are represented by their first child, and
    # type references are followed as long as neither side needs its own function.
    def int_val(self):
        if self.key:
            return self.key.int_val()
        if self.type in ("UINT", "NINT") and self.is_unambiguous():
            return self.value
        if self.type == "GROUP" and not self.count_var_condition():
            return self.value[0].int_val()
        if self.type == "OTHER" \
                and not (self.count_var_condition()
                         or self.single_func_impl_condition()
                         or self.my_types[self.value].single_func_impl_condition()):
            return self.my_types[self.value].int_val()
        return None

    # Whether this element can be identified by an integer value, see int_val().
    def is_int_disambiguated(self):
        return self.int_val() is not None

    # Whether all children have distinct integer values (see int_val()) that lie within
    # [min_val, max_val].
    def all_children_disambiguated(self, min_val, max_val):
        values = {child.int_val() for child in self.value}
        if len(values) != len(self.value) or None in values:
            return False
        return max(values) <= max_val and min(values) >= min_val

    def all_children_int_disambiguated(self):
        return self.all_children_disambiguated(INT32_MIN, INT32_MAX)

    def all_children_uint_disambiguated(self):
        return self.all_children_disambiguated(0, INT32_MAX)

    # Name of the "present" variable for this element.
    def present_var_name(self):
        return "%s_present" % (self.var_name())

    # Full "path" of the "present" variable for this element.
    def present_var_access(self):
        return self.access_append(self.present_var_name())

    # Name of the "count" variable for this element.
    def count_var_name(self):
        return "%s_count" % (self.var_name())

    # Full "path" of the "count" variable for this element.
    def count_var_access(self):
        return self.access_append(self.count_var_name())

    # Name of the "choice" variable for this element.
    def choice_var_name(self):
        return self.var_name() + "_choice"

    # Name of the enum entry for this element.
    def enum_var_name(self):
        return self.var_name(with_prefix=True) + "_c"

    # Enum entry for this element. If int_val is true, the entry is explicitly assigned this
    # element's integer value.
    def enum_var(self, int_val=False):
        if int_val:
            return f"{self.enum_var_name()} = {val_to_str(self.int_val())}"
        return self.enum_var_name()

    # Full "path" of the "choice" variable for this element.
    def choice_var_access(self):
        return self.access_append(self.choice_var_name())
1277
1278
class CddlValidationError(Exception):
    """Raised when data does not decode correctly according to the CDDL."""
1281
1282
1283# Subclass of tuple for holding key,value pairs. This is to make it possible to use isinstance() to
1284# separate it from other tuples.
class KeyTuple(tuple):
    """A tuple subclass, so that key/value pairs can be told apart with isinstance()."""

    def __new__(cls, *in_tuple):
        # Forward the arguments unchanged; like tuple(), this accepts at most one iterable.
        return super().__new__(cls, *in_tuple)
1288
1289
1290# Convert data between CBOR, JSON and YAML, and validate against the provided CDDL.
1291# Decode and validate CBOR into Python structures to be able to make Python scripts that manipulate
1292# CBOR code.
1293class DataTranslator(CddlXcoder):
1294    # Format a Python object for printing by adding newlines and indentation.
1295    @staticmethod
1296    def format_obj(obj):
1297        formatted = pformat(obj)
1298        out_str = ""
1299        indent = 0
1300        new_line = True
1301        for c in formatted:
1302            if new_line:
1303                if c == " ":
1304                    continue
1305                new_line = False
1306            out_str += c
1307            if c in "[(":
1308                indent += 1
1309            if c in ")]" and indent > 0:
1310                indent -= 1
1311            if c in "[(,":
1312                out_str += linesep
1313                out_str += "  " * indent
1314                new_line = True
1315        return out_str
1316
1317    # Override the id() function. If the name starts with an underscore, prepend an 'f',
1318    # since namedtuple() doesn't support identifiers that start with an underscore.
1319    def id(self):
1320        return getrp(r"\A_").sub("f_", self.generate_base_name())
1321
1322    # Override the var_name()
1323    def var_name(self):
1324        return self.id()
1325
1326    # Check a condition and raise a CddlValidationError if not.
1327    def _decode_assert(self, test, msg=""):
1328        if not test:
1329            raise CddlValidationError(f"Data did not decode correctly {'('+msg+')' if msg else ''}")
1330
    # Check that no unexpected tags are attached to this data. Return obj with the accepted tags
    # removed.
1332    def _check_tag(self, obj):
1333        tags = copy(self.tags)  # All expected tags
1334        # Process all tags present in obj
1335        while isinstance(obj, CBORTag):
1336            if obj.tag in tags or self.type == "ANY":
1337                if obj.tag in tags:
1338                    tags.remove(obj.tag)
1339                obj = obj.value
1340                continue
1341            elif self.type in ["OTHER", "GROUP", "UNION"]:
1342                break
1343            self._decode_assert(False, f"Tag ({obj.tag}) not expected for {self}")
1344        # Check that all expected tags were found in obj.
1345        self._decode_assert(not tags, f"Expected tags ({tags}), but none present.")
1346        return obj
1347
1348    # Return our expected python type as returned by cbor2.
1349    def _expected_type(self):
1350        return {
1351            "UINT": lambda: (int,),
1352            "INT": lambda: (int,),
1353            "NINT": lambda: (int,),
1354            "FLOAT": lambda: (float,),
1355            "TSTR": lambda: (str,),
1356            "BSTR": lambda: (bytes,),
1357            "NIL": lambda: (type(None),),
1358            "UNDEF": lambda: (type(undefined),),
1359            "ANY": lambda: (int, float, str, bytes, type(None), type(undefined), bool, list, dict),
1360            "BOOL": lambda: (bool,),
1361            "LIST": lambda: (tuple, list),
1362            "MAP": lambda: (dict,),
1363        }[self.type]()
1364
1365    # Check that the decoded object has the correct type.
1366    def _check_type(self, obj):
1367        if self.type not in ["OTHER", "GROUP", "UNION"]:
1368            exp_type = self._expected_type()
1369            self._decode_assert(
1370                type(obj) in exp_type,
1371                f"{str(self)}: Wrong type ({type(obj)}) of {str(obj)}, expected {str(exp_type)}")
1372
    # Check that the decoded value conforms to the restrictions in the CDDL.
1374    def _check_value(self, obj):
1375        if self.type in ["UINT", "INT", "NINT", "FLOAT", "TSTR", "BSTR", "BOOL"] \
1376                and self.value is not None:
1377            value = self.value
1378            if self.type == "BSTR":
1379                value = self.value.encode("utf-8")
1380            self._decode_assert(
1381                self.value == obj,
1382                f"{obj} should have value {self.value} according to {self.var_name()}")
1383        if self.type in ["UINT", "INT", "NINT", "FLOAT"]:
1384            if self.min_value is not None:
1385                self._decode_assert(obj >= self.min_value, "Minimum value: " + str(self.min_value))
1386            if self.max_value is not None:
1387                self._decode_assert(obj <= self.max_value, "Maximum value: " + str(self.max_value))
1388        if self.type == "UINT":
1389            if self.bits:
1390                mask = sum(((1 << b.value) for b in self.my_control_groups[self.bits].value))
1391                self._decode_assert(not (obj & ~mask), "Allowed bitmask: " + bin(mask))
1392        if self.type in ["TSTR", "BSTR"]:
1393            if self.min_size is not None:
1394                self._decode_assert(
1395                    len(obj) >= self.min_size, "Minimum length: " + str(self.min_size))
1396            if self.max_size is not None:
1397                self._decode_assert(
1398                    len(obj) <= self.max_size, "Maximum length: " + str(self.max_size))
1399
1400    # Check that the object is not a KeyTuple since that means it hasn't been properly processed.
1401    def _check_key(self, obj):
1402        self._decode_assert(
1403            not isinstance(obj, KeyTuple), "Unexpected key found: (key,value)=" + str(obj))
1404
1405    # Recursively remove intermediate objects that have single members. Keep lists as is.
1406    def _flatten_obj(self, obj):
1407        if isinstance(obj, tuple) and len(obj) == 1:
1408            return self._flatten_obj(obj[0])
1409        return obj
1410
    # Return the contents of a list if it has a single member and its name is the same as ours.
1412    def _flatten_list(self, name, obj):
1413        if (isinstance(obj, list)
1414                and len(obj) == 1
1415                and (isinstance(obj[0], list) or isinstance(obj[0], tuple))
1416                and len(obj[0]) == 1
1417                and hasattr(obj[0], name)):
1418            return [obj[0][0]]
1419        return obj
1420
1421    # Construct a namedtuple object from my_list. my_list contains tuples of name/value.
1422    # Also, attempt to flatten redundant levels of abstraction.
1423    def _construct_obj(self, my_list):
1424        if my_list == []:
1425            return None
1426        names, values = tuple(zip(*my_list))
1427        if len(values) == 1:
1428            values = (self._flatten_obj(values[0]), )
1429        values = tuple(self._flatten_list(names[i], values[i]) for i in range(len(values)))
1430        assert (not any((isinstance(elem, KeyTuple) for elem in values))), \
1431            f"KeyTuple not processed: {values}"
1432        return namedtuple("_", names)(*values)
1433
    # Construct obj and add it to my_list if relevant. Also, process any KeyTuples present.
    def _add_if(self, my_list, obj, expect_key=False, name=None):
        """Append the decoded obj to my_list as one or more (name, value) tuples.

        Nothing is appended when this element is_unambiguous(). The entry name is `name` if
        given, otherwise self.var_name().

        my_list:    List of (name, value) tuples; it is extended in place.
        obj:        The decoded value; may be a list (whose KeyTuple members are replaced in
                    place by constructed objects) or a KeyTuple.
        expect_key: When set, and this element is an "OTHER" without a key, the call is handed
                    over to the referenced type (without forwarding expect_key or name).
        name:       Optional override of the name used for the appended entry.
        """
        if expect_key and self.type == "OTHER" and self.key is None:
            # Delegate to the type this element refers to.
            self.my_types[self.value]._add_if(my_list, obj)
            return
        if self.is_unambiguous():
            return
        if isinstance(obj, list):
            for i in range(len(obj)):
                if isinstance(obj[i], KeyTuple):
                    # Replace the KeyTuple in the list with an object constructed from it.
                    retvals = list()
                    self._add_if(retvals, obj[i])
                    obj[i] = self._construct_obj(retvals)
                if self.type == "BSTR" and self.cbor_var_condition() and isinstance(obj[i], bytes):
                    assert all((isinstance(o, bytes) for o in obj)), \
                           """Unsupported configuration for cbor bstr. If a list contains a
CBOR-formatted bstr, all elements must be bstrs. If not, it is a programmer error."""
        if isinstance(obj, KeyTuple):
            # Split the pair; a key of None is not recorded.
            key, obj = obj
            if key is not None:
                self.key._add_if(my_list, key, name=self.var_name() + "_key")
        if self.type == "BSTR" and self.cbor_var_condition():
            # If a bstr is CBOR-formatted, add both the string and the decoding of the string here
            if isinstance(obj, list) and all((isinstance(o, bytes) for o in obj)):
                # One or more bstr in a list (i.e. it is optional or repeated)
                my_list.append((name or self.var_name(), [self.cbor.decode_str(o) for o in obj]))
                my_list.append(((name or self.var_name()) + "_bstr", obj))
                return
            if isinstance(obj, bytes):
                my_list.append((name or self.var_name(), self.cbor.decode_str(obj)))
                my_list.append(((name or self.var_name()) + "_bstr", obj))
                return
        my_list.append((name or self.var_name(), obj))
1467
1468    # Throw CddlValidationError if iterator is not empty. This consumes one element if present.
1469    def _iter_is_empty(self, it):
1470        try:
1471            val = next(it)
1472        except StopIteration:
1473            return True
1474        raise CddlValidationError(
1475            f"Iterator not consumed while parsing \n{self}\nRemaining elements:\n elem: "
1476            + "\n elem: ".join(str(elem) for elem in ([val] + list(it))))
1477
1478    # Get next element from iterator, throw CddlValidationError instead of StopIteration.
1479    def _iter_next(self, it):
1480        try:
1481            next_obj = next(it)
1482            return next_obj
1483        except StopIteration:
1484            raise CddlValidationError("Iterator empty")
1485
1486    # Decode single CDDL value, excluding repetitions
    def _decode_single_obj(self, obj):
        """Validate and decode one CBOR object against this element, ignoring quantifiers.

        The object is first run through the key, tag, type and value checks (the tag check may
        replace obj with another object). What happens next depends on self.type:
         - Primitive types and "ANY": the object is returned as is.
         - "OTHER": decoding is handed to the referenced type.
         - "LIST"/"MAP": each child consumes from an iterator over the members, and the
           collected results are returned as an object built by _construct_obj().
         - "UNION": children are tried in order; the first one that decodes without a
           CddlValidationError wins, and a "union_choice" entry with its var_name is added.
        """
        self._check_key(obj)
        obj = self._check_tag(obj)
        self._check_type(obj)
        self._check_value(obj)
        if self.type in ["UINT", "INT", "NINT", "FLOAT", "TSTR",
                         "BSTR", "BOOL", "NIL", "UNDEF", "ANY"]:
            return obj
        elif self.type == "OTHER":
            return self.my_types[self.value]._decode_single_obj(obj)
        elif self.type == "LIST":
            retval = list()
            child_val = iter(obj)
            for child in self.value:
                # Each child returns the iterator to continue from, plus what it decoded.
                ret = child._decode_full(child_val)
                child_val, child_obj = ret
                child._add_if(retval, child_obj)
            # All list members must have been consumed by the children.
            self._iter_is_empty(child_val)
            return self._construct_obj(retval)
        elif self.type == "MAP":
            retval = list()
            # Map entries are presented to the children as KeyTuples of (key, value).
            child_val = iter(KeyTuple(item) for item in obj.items())
            for child in self.value:
                child_val, child_key_val = child._decode_full(child_val)
                child._add_if(retval, child_key_val, expect_key=True)
            self._iter_is_empty(child_val)
            return self._construct_obj(retval)
        elif self.type == "UNION":
            retval = list()
            for child in self.value:
                try:
                    child_obj = child._decode_single_obj(obj)
                    child._add_if(retval, child_obj)
                    retval.append(("union_choice", child.var_name()))
                    return self._construct_obj(retval)
                except CddlValidationError as c:
                    # Remember why this alternative failed, then try the next one.
                    self.errors.append(str(c))
            self._decode_assert(False, "No matches for union: " + str(self))
        assert False, "Unexpected type: " + self.type
1526
1527    # Decode key and value in the form of a KeyTuple
1528    def _handle_key(self, next_obj):
1529        self._decode_assert(
1530            isinstance(next_obj, KeyTuple), f"Expected key: {self.key} value=" + pformat(next_obj))
1531        key, obj = next_obj
1532        key_res = self.key._decode_single_obj(key)
1533        obj_res = self._decode_single_obj(obj)
1534        res = KeyTuple((key_res if not self.key.is_unambiguous() else None, obj_res))
1535        return res
1536
1537    # Decode single CDDL value, excluding repetitions. May consume 0 to n CBOR objects via the
1538    # iterator.
1539    def _decode_obj(self, it):
1540        my_list = list()
1541        if self.key is not None:
1542            it, it_copy = tee(it)
1543            key_res = self._handle_key(self._iter_next(it_copy))
1544            return it_copy, key_res
1545        if self.tags:
1546            it, it_copy = tee(it)
1547            maybe_tag = next(it_copy)
1548            if isinstance(maybe_tag, CBORTag):
1549                tag_res = self._decode_single_obj(maybe_tag)
1550                return it_copy, tag_res
1551        if self.type == "OTHER" and self.key is None:
1552            return self.my_types[self.value]._decode_full(it)
1553        elif self.type == "GROUP":
1554            my_list = list()
1555            child_it = it
1556            for child in self.value:
1557                child_it, child_obj = child._decode_full(child_it)
1558                if child.key is not None:
1559                    child._add_if(my_list, child_obj, expect_key=True)
1560                else:
1561                    child._add_if(my_list, child_obj)
1562            ret = (child_it, self._construct_obj(my_list))
1563        elif self.type == "UNION":
1564            my_list = list()
1565            child_it = it
1566            found = False
1567            for child in self.value:
1568                try:
1569                    child_it, it_copy = tee(child_it)
1570                    child_it, child_obj = child._decode_full(child_it)
1571                    child._add_if(my_list, child_obj)
1572                    my_list.append(("union_choice", child.var_name()))
1573                    ret = (child_it, self._construct_obj(my_list))
1574                    found = True
1575                    break
1576                except CddlValidationError as c:
1577                    self.errors.append(str(c))
1578                    child_it = it_copy
1579            self._decode_assert(found, "No matches for union: " + str(self))
1580        else:
1581            ret = (it, self._decode_single_obj(self._iter_next(it)))
1582        return ret
1583
1584    # Decode single CDDL value, with repetitions. May consume 0 to n CBOR objects via the iterator.
1585    def _decode_full(self, it):
1586        if self.multi_var_condition():
1587            retvals = []
1588            for i in range(self.min_qty):
1589                it, retval = self._decode_obj(it)
1590                retvals.append(retval if not self.is_unambiguous_repeated() else None)
1591            try:
1592                for i in range(self.max_qty - self.min_qty):
1593                    it, it_copy = tee(it)
1594                    it, retval = self._decode_obj(it)
1595                    retvals.append(retval if not self.is_unambiguous_repeated() else None)
1596            except CddlValidationError as c:
1597                self.errors.append(str(c))
1598                it = it_copy
1599            return it, retvals
1600        else:
1601            ret = self._decode_obj(it)
1602            return ret
1603
1604    # CBOR object => python object
1605    def decode_obj(self, obj):
1606        it = iter([obj])
1607        try:
1608            _, decoded = self._decode_full(it)
1609            self._iter_is_empty(it)
1610        except CddlValidationError as e:
1611            if self.errors:
1612                print("Errors:")
1613                pprint(self.errors)
1614            raise e
1615        return decoded
1616
1617    # YAML => python object
1618    def decode_str_yaml(self, yaml_str, yaml_compat=False):
1619        yaml_obj = yaml_load(yaml_str)
1620        obj = self._from_yaml_obj(yaml_obj) if yaml_compat else yaml_obj
1621        self.validate_obj(obj)
1622        return self.decode_obj(obj)
1623
1624    # CBOR bytestring => python object
1625    def decode_str(self, cbor_str):
1626        cbor_obj = loads(cbor_str)
1627        return self.decode_obj(cbor_obj)
1628
1629    # Validate CBOR object against CDDL. Exception if not valid.
1630    def validate_obj(self, obj):
1631        self.decode_obj(obj)
1632        return True
1633
1634    # Validate CBOR bytestring against CDDL. Exception if not valid.
1635    def validate_str(self, cbor_str):
1636        cbor_obj = loads(cbor_str)
1637        return self.validate_obj(cbor_obj)
1638
1639    # Convert object from YAML/JSON (with special dicts for bstr, tag etc) to CBOR object
1640    # that cbor2 understands.
1641    def _from_yaml_obj(self, obj):
1642        if isinstance(obj, list):
1643            if len(obj) == 1 and obj[0] == "zcbor_undefined":
1644                return undefined
1645            return [self._from_yaml_obj(elem) for elem in obj]
1646        elif isinstance(obj, dict):
1647            if ["zcbor_bstr"] == list(obj.keys()):
1648                if isinstance(obj["zcbor_bstr"], str):
1649                    bstr = bytes.fromhex(obj["zcbor_bstr"])
1650                else:
1651                    bstr = dumps(self._from_yaml_obj(obj["zcbor_bstr"]))
1652                return bstr
1653            elif ["zcbor_tag", "zcbor_tag_val"] == list(obj.keys()):
1654                return CBORTag(obj["zcbor_tag"], self._from_yaml_obj(obj["zcbor_tag_val"]))
1655            retval = dict()
1656            for key, val in obj.items():
1657                match = getrp(r"zcbor_keyval\d+").fullmatch(key)
1658                if match is not None:
1659                    new_key = self._from_yaml_obj(val["key"])
1660                    new_val = self._from_yaml_obj(val["val"])
1661                    if isinstance(new_key, list):
1662                        new_key = tuple(new_key)
1663                    retval[new_key] = new_val
1664                else:
1665                    retval[key] = self._from_yaml_obj(val)
1666            return retval
1667        return obj
1668
1669    # inverse of _from_yaml_obj
1670    def _to_yaml_obj(self, obj):
1671        if isinstance(obj, list) or isinstance(obj, tuple):
1672            return [self._to_yaml_obj(elem) for elem in obj]
1673        elif isinstance(obj, dict):
1674            retval = dict()
1675            i = 0
1676            for key, val in obj.items():
1677                if not isinstance(key, str):
1678                    retval[f"zcbor_keyval{i}"] = {
1679                        "key": self._to_yaml_obj(key), "val": self._to_yaml_obj(val)}
1680                    i += 1
1681                else:
1682                    retval[key] = self._to_yaml_obj(val)
1683            return retval
1684        elif isinstance(obj, bytes):
1685            f = BytesIO(obj)
1686            try:
1687                bstr_obj = self._to_yaml_obj(load(f))
1688            except (CBORDecodeValueError, CBORDecodeEOF):
1689                # failed decoding
1690                bstr_obj = obj.hex()
1691            else:
1692                if f.read(1) != b'':
1693                    # not fully decoded
1694                    bstr_obj = obj.hex()
1695            return {"zcbor_bstr": bstr_obj}
1696        elif isinstance(obj, CBORTag):
1697            return {"zcbor_tag": obj.tag, "zcbor_tag_val": self._to_yaml_obj(obj.value)}
1698        elif obj is undefined:
1699            return ["zcbor_undefined"]
1700        assert not isinstance(obj, bytes)
1701        return obj
1702
1703    # YAML str => CBOR bytestr
1704    def from_yaml(self, yaml_str, yaml_compat=False):
1705        yaml_obj = yaml_load(yaml_str)
1706        obj = self._from_yaml_obj(yaml_obj) if yaml_compat else yaml_obj
1707        self.validate_obj(obj)
1708        return dumps(obj)
1709
1710    # CBOR object => YAML str
1711    def obj_to_yaml(self, obj, yaml_compat=False):
1712        self.validate_obj(obj)
1713        yaml_obj = self._to_yaml_obj(obj) if yaml_compat else obj
1714        return yaml_dump(yaml_obj)
1715
1716    # CBOR bytestring => YAML str
1717    def str_to_yaml(self, cbor_str, yaml_compat=False):
1718        return self.obj_to_yaml(loads(cbor_str), yaml_compat=yaml_compat)
1719
1720    # JSON str => CBOR bytestr
1721    def from_json(self, json_str, yaml_compat=False):
1722        json_obj = json_load(json_str)
1723        obj = self._from_yaml_obj(json_obj) if yaml_compat else json_obj
1724        self.validate_obj(obj)
1725        return dumps(obj)
1726
1727    # CBOR object => JSON str
1728    def obj_to_json(self, obj, yaml_compat=False):
1729        self.validate_obj(obj)
1730        json_obj = self._to_yaml_obj(obj) if yaml_compat else obj
1731        return json_dump(json_obj)
1732
1733    # CBOR bytestring => JSON str
1734    def str_to_json(self, cbor_str, yaml_compat=False):
1735        return self.obj_to_json(loads(cbor_str), yaml_compat=yaml_compat)
1736
1737    # CBOR bytestring => C code (uint8_t array initialization)
1738    def str_to_c_code(self, cbor_str, var_name, columns=0):
1739        arr = ", ".join(f"0x{c:02x}" for c in cbor_str)
1740        if columns:
1741            arr = '\n' + indent("\n".join(wrap(arr, 6 * columns)), '\t') + '\n'
1742        return f'uint8_t {var_name}[] = {{{arr}}};\n'
1743
1744
class XcoderTuple(NamedTuple):
    """Immutable record grouping a body with a function name and a type name.

    NOTE(review): presumably describes one generated encode/decode function (its code lines,
    its name and the type it operates on) - confirm against the code that creates it.
    """
    # A list; presumably lines of code - TODO confirm.
    body: list
    func_name: str
    type_name: str
1749
1750
class CddlTypes(NamedTuple):
    """Immutable pair of the two dicts that elements look definitions up in.

    NOTE(review): both dicts appear to be keyed by definition name (elements index them with
    self.value and self.bits, respectively) - confirm against the parser.
    """
    my_types: dict
    my_control_groups: dict
1754
1755
1756# Class for generating C code that encode/decodes CBOR and validates it according
1757# to the CDDL.
1758class CodeGenerator(CddlXcoder):
1759    def __init__(self, mode, entry_type_names, default_bit_size, *args, **kwargs):
1760        super(CodeGenerator, self).__init__(*args, **kwargs)
1761        self.mode = mode
1762        self.entry_type_names = entry_type_names
1763        self.default_bit_size = default_bit_size
1764
1765    @classmethod
1766    def from_cddl(cddl_class, mode, *args, **kwargs):
1767        cddl_res = super(CodeGenerator, cddl_class).from_cddl(*args, **kwargs)
1768
1769        # set access prefix (struct access paths) for all the definitions.
1770        for my_type in cddl_res.my_types:
1771            cddl_res.my_types[my_type].set_access_prefix(f"(*{struct_ptr_name(mode)})")
1772
1773        return cddl_res
1774
1775    # Whether this element (an OTHER) refers to an entry type.
1776    def is_entry_type(self):
1777        return (self.type == "OTHER") and (self.value in self.entry_type_names)
1778
1779    # Whether to include a "cbor" variable for this element.
1780    def is_cbor(self):
1781        res = (self.type_name() is not None) and not self.is_entry_type() and (
1782            (self.type != "OTHER") or self.my_types[self.value].is_cbor())
1783        return res
1784
1785    def init_args(self):
1786        return (self.mode, self.entry_type_names, self.default_bit_size, self.default_max_qty)
1787
1788    # Declaration of the "present" variable for this element.
1789    def present_var(self):
1790        return ["bool %s;" % self.present_var_name()]
1791
1792    # Declaration of the "count" variable for this element.
1793    def count_var(self):
1794        return ["size_t %s;" % self.count_var_name()]
1795
1796    # Declaration of the "choice" variable for this element.
1797    def anonymous_choice_var(self):
1798        int_vals = self.all_children_int_disambiguated()
1799        return self.enclose("enum", [val.enum_var(int_vals) + "," for val in self.value])
1800
1801    # Declaration of the "choice" variable for this element.
1802    def choice_var(self):
1803        var = self.anonymous_choice_var()
1804        var[-1] += f" {self.choice_var_name()};"
1805        return var
1806
1807    # Declaration of the variables of all children.
1808    def child_declarations(self):
1809        decl = [line for child in self.value for line in child.full_declaration()]
1810        return decl
1811
1812    # Declaration of the variables of all children.
1813    def child_single_declarations(self):
1814        decl = list()
1815        for child in self.value:
1816            if not child.is_unambiguous_repeated():
1817                decl.extend(child.add_var_name(
1818                    child.single_var_type(), anonymous=True))
1819        return decl
1820
1821    def simple_func_condition(self):
1822        if self.single_func_impl_condition():
1823            return True
1824        if self.type == "OTHER" and self.my_types[self.value].simple_func_condition():
1825            return True
1826        return False
1827
1828    # Base name if this element needs to declare a type.
1829    def raw_type_name(self):
1830        return "struct %s" % self.id()
1831
1832    def enum_type_name(self):
1833        return "enum %s" % self.id()
1834
1835    # The bit width of the integers as represented in code.
1836    def bit_size(self):
1837        bit_size = None
1838        if self.type in ["UINT", "INT", "NINT"]:
1839            assert self.default_bit_size in [32, 64], "The default_bit_size must be 32 or 64."
1840            if self.default_bit_size == 64:
1841                bit_size = 64
1842            else:
1843                bit_size = 32
1844
1845                for v in [self.value or 0, self.max_value or 0, self.min_value or 0]:
1846                    if (type(v) is str):
1847                        if "64" in v:
1848                            bit_size = 64
1849                    elif self.type == "UINT":
1850                        if (v > UINT32_MAX):
1851                            bit_size = 64
1852                    else:
1853                        if (v > INT32_MAX) or (v < INT32_MIN):
1854                            bit_size = 64
1855        return bit_size
1856
1857    def float_type(self):
1858        if self.type != "FLOAT":
1859            return None
1860
1861        max_size = self.max_size or 8
1862
1863        if max_size <= 4:
1864            return "float"
1865        elif max_size == 8:
1866            return "double"
1867        else:
1868            raise TypeError("Floats must have 4 or 8 bytes of precision.")
1869
1870    # Name of the type of this element's actual value variable.
1871    def val_type_name(self):
1872        if self.multi_val_condition():
1873            return self.raw_type_name()
1874
1875        # Will fail runtime if we don't use lambda for type_name()
1876        # pylint: disable=unnecessary-lambda
1877        name = {
1878            "INT": lambda: f"int{self.bit_size()}_t",
1879            "UINT": lambda: f"uint{self.bit_size()}_t",
1880            "NINT": lambda: f"int{self.bit_size()}_t",
1881            "FLOAT": lambda: self.float_type(),
1882            "BSTR": lambda: "struct zcbor_string",
1883            "TSTR": lambda: "struct zcbor_string",
1884            "BOOL": lambda: "bool",
1885            "NIL": lambda: None,
1886            "UNDEF": lambda: None,
1887            "ANY": lambda: None,
1888            "LIST": lambda: self.value[0].type_name() if len(self.value) >= 1 else None,
1889            "MAP": lambda: self.value[0].type_name() if len(self.value) >= 1 else None,
1890            "GROUP": lambda: self.value[0].type_name() if len(self.value) >= 1 else None,
1891            "UNION": lambda: self.union_type(),
1892            "OTHER": lambda: self.my_types[self.value].type_name(),
1893        }[self.type]()
1894
1895        return name
1896
    # Name of the type for the repeated part of this element.
1898    # I.e. the part that happens multiple times if the element has a quantifier.
1899    # I.e. not including things like the "count" or "present" variable.
1900    def repeated_type_name(self):
1901        if self.self_repeated_multi_var_condition():
1902            name = self.raw_type_name()
1903            if self.val_type_name() == name:
1904                name = name + "_r"
1905        else:
1906            name = self.val_type_name()
1907        return name
1908
1909    # Name of the type for this element.
1910    def type_name(self):
1911        if self.multi_var_condition():
1912            name = self.raw_type_name()
1913        else:
1914            name = self.repeated_type_name()
1915        return name
1916
1917    # Take a multi member type name and create a variable declaration. Make it an array if the
1918    # element is repeated.
1919    def add_var_name(self, var_type, full=False, anonymous=False):
1920        if var_type:
1921            assert (var_type[-1][-1] == "}" or len(var_type) == 1), \
1922                f"Expected single var: {var_type!r}"
1923            if not anonymous or var_type[-1][-1] != "}":
1924                var_type[-1] += " %s%s" % (
1925                    self.var_name(), "[%s]" % self.max_qty if full and self.max_qty != 1 else "")
1926            var_type = add_semicolon(var_type)
1927        return var_type
1928
1929    # The type for this element as a member variable.
1930    def var_type(self):
1931        if not self.multi_val_condition() and self.val_type_name() is not None:
1932            return [self.val_type_name()]
1933        elif self.type == "UNION":
1934            return self.union_type()
1935        return []
1936
1937    # Enclose a list of declarations in a block (struct, union or enum).
1938    def enclose(self, ingress, declaration):
1939        if declaration:
1940            return [f"{ingress} {{"] + [indentation + line for line in declaration] + ["}"]
1941        else:
1942            return []
1943
1944    # Type declaration for unions.
1945    def union_type(self):
1946        declaration = self.enclose("union", self.child_single_declarations())
1947        return declaration
1948
1949    # Declaration of the repeated part of this element.
1950    def repeated_declaration(self):
1951        if self.is_unambiguous_repeated():
1952            return []
1953
1954        var_type = self.var_type()
1955        multi_var = False
1956
1957        decl = self.add_var_name(var_type, anonymous=(self.type == "UNION"))
1958
1959        if self.type in ["LIST", "MAP", "GROUP"]:
1960            decl += self.child_declarations()
1961            multi_var = len(decl) > 1
1962
1963        if self.reduced_key_var_condition():
1964            key_var = self.key.full_declaration()
1965            decl = key_var + decl
1966            multi_var = key_var != []
1967
1968        if self.choice_var_condition():
1969            choice_var = self.choice_var()
1970            decl += choice_var
1971            multi_var = choice_var != []
1972
1973        if self.cbor_var_condition():
1974            cbor_var = self.cbor.full_declaration()
1975            decl += cbor_var
1976            multi_var = cbor_var != []
1977
1978        # if self.type not in ["LIST", "MAP", "GROUP"] or len(self.value) <= 1:
1979            # This assert isn't accurate for value lists with NIL, "UNDEF", or ANY
1980            # members.
1981            # assert multi_var == self.repeated_multi_var_condition(
1982            # ), f"""assert {multi_var} == {self.repeated_multi_var_condition()}
1983            # type: {self.type}
1984            # decl {decl}
1985            # self.key_var_condition() is {self.key_var_condition()}
1986            # self.key is {self.key}
1987            # self.key.full_declaration() is {self.key.full_declaration()}
1988            # self.cbor_var_condition() is {self.cbor_var_condition()}
1989            # self.choice_var_condition() is {self.choice_var_condition()}
1990            # self.value is {self.value}"""
1991        return decl
1992
1993    # Declaration of the full type for this element.
1994    def full_declaration(self):
1995        multi_var = False
1996
1997        if self.is_unambiguous():
1998            return []
1999
2000        if self.multi_var_condition():
2001            if self.is_unambiguous_repeated():
2002                decl = []
2003            else:
2004                decl = self.add_var_name(
2005                    [self.repeated_type_name()]
2006                    if self.repeated_type_name() is not None else [], full=True)
2007        else:
2008            decl = self.repeated_declaration()
2009
2010        if self.count_var_condition():
2011            count_var = self.count_var()
2012            decl += count_var
2013            multi_var = count_var != []
2014
2015        if self.present_var_condition():
2016            present_var = self.present_var()
2017            decl += present_var
2018            multi_var = present_var != []
2019
2020        assert multi_var == self.multi_var_condition()
2021
2022        return decl
2023
2024    # Return the type definition of this element. If there are multiple variables, wrap them in a
2025    # struct so the function always returns a single type with no name. If full is False, only
2026    # repeated part is used.
2027    def single_var_type(self, full=True):
2028        if full and self.multi_member():
2029            return self.enclose("struct", self.full_declaration())
2030        elif not full and self.repeated_multi_var_condition():
2031            return self.enclose("struct", self.repeated_declaration())
2032        else:
2033            return self.var_type()
2034
2035    # Return the type definition of this element, and all its children + key +
2036    # cbor.
2037    def type_def(self):
2038        ret_val = []
2039        if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
2040            ret_val.extend(
2041                [elem for typedef in [
2042                    child.type_def() for child in self.value] for elem in typedef])
2043        if self.bits:
2044            ret_val.extend(self.my_control_groups[self.bits].type_def_bits())
2045        if self.cbor_var_condition():
2046            ret_val.extend(self.cbor.type_def())
2047        if self.reduced_key_var_condition():
2048            ret_val.extend(self.key.type_def())
2049        if self.type == "OTHER":
2050            ret_val.extend(self.my_types[self.value].type_def())
2051        if self.repeated_type_def_condition():
2052            type_def_list = self.single_var_type(full=False)
2053            if type_def_list:
2054                ret_val.extend([(self.single_var_type(full=False), self.repeated_type_name())])
2055        if self.type_def_condition():
2056            type_def_list = self.single_var_type()
2057            if type_def_list:
2058                ret_val.extend([(self.single_var_type(), self.type_name())])
2059        return ret_val
2060
2061    def type_def_bits(self):
2062        tdef = self.anonymous_choice_var()
2063        return [(tdef, self.enum_type_name())]
2064
2065    def float_prefix(self):
2066        if self.type != "FLOAT":
2067            return ""
2068
2069        min_size = self.min_size or 2
2070        max_size = self.max_size or 8
2071
2072        if max_size == 2:
2073            return "float16"
2074        elif min_size == 2 and max_size == 4:
2075            return "float16_32" if self.mode == "decode" else "float32"
2076        if min_size == 4 and max_size == 4:
2077            return "float32"
2078        elif min_size == 4 and max_size == 8:
2079            return "float32_64" if self.mode == "decode" else "float64"
2080        elif min_size == 8 and max_size == 8:
2081            return "float64"
2082        elif min_size <= 4 and max_size == 8:
2083            return "float" if self.mode == "decode" else "float64"
2084        else:
2085            raise TypeError("Floats must have 2, 4 or 8 bytes of precision.")
2086
2087    def single_func_prim_prefix(self):
2088        if self.type == "OTHER":
2089            return self.my_types[self.value].single_func_prim_prefix()
2090        return ({
2091            "INT": f"zcbor_int{self.bit_size()}",
2092            "UINT": f"zcbor_uint{self.bit_size()}",
2093            "NINT": f"zcbor_int{self.bit_size()}",
2094            "FLOAT": f"zcbor_{self.float_prefix()}",
2095            "BSTR": f"zcbor_bstr",
2096            "TSTR": f"zcbor_tstr",
2097            "BOOL": f"zcbor_bool",
2098            "NIL": f"zcbor_nil",
2099            "UNDEF": f"zcbor_undefined",
2100            "ANY": f"zcbor_any",
2101        }[self.type])
2102
2103    # Name of the encoder/decoder function for this element.
2104    def xcode_func_name(self):
2105        return f"{self.mode}_{self.var_name(with_prefix=True)}"
2106
2107    # Name of the encoder/decoder function for the repeated part of this element.
2108    def repeated_xcode_func_name(self):
2109        return f"{self.mode}_repeated_{self.var_name(with_prefix=True)}"
2110
2111    def single_func_prim_name(self, union_int=None, ptr_result=False):
2112        """Function name for xcoding this type, when it is a primitive type"""
2113        ptr_variant = ptr_result and self.type in ["UINT", "INT", "NINT", "FLOAT", "BOOL"]
2114        func_prefix = self.single_func_prim_prefix()
2115        if self.mode == "decode":
2116            if self.type == "ANY":
2117                func = "zcbor_any_skip"
2118            elif not self.is_unambiguous_value():
2119                func = f"{func_prefix}_decode"
2120            elif not union_int:
2121                func = f"{func_prefix}_{'pexpect' if ptr_variant else 'expect'}"
2122            elif union_int == "EXPECT":
2123                assert not ptr_variant, \
2124                       "Programmer error: invalid use of expect_union."
2125                func = f"{func_prefix}_expect_union"
2126            elif union_int == "DROP":
2127                return None
2128        else:
2129            if self.type == "ANY":
2130                func = "zcbor_nil_put"
2131            elif (not self.is_unambiguous_value()) or self.type in ["TSTR", "BSTR"] or ptr_variant:
2132                func = f"{func_prefix}_encode"
2133            else:
2134                func = f"{func_prefix}_put"
2135        return func
2136
2137    def single_func_prim(self, access, union_int=None, ptr_result=False):
2138        """Return the function name and arguments to call to encode/decode this element. Only used
2139        when this element DOESN'T define its own encoder/decoder function (when it's a primitive
2140        type, for which functions already exist, or when the function is defined elsewhere
2141        ("OTHER"))
2142        """
2143        assert self.type not in ["LIST", "MAP"], "Must have wrapper function for list or map."
2144
2145        if self.type == "OTHER":
2146            return self.my_types[self.value].single_func(access, union_int)
2147
2148        func_name = self.single_func_prim_name(union_int, ptr_result=ptr_result)
2149        if func_name is None:
2150            return (None, None)
2151
2152        if self.type in ["NIL", "UNDEF", "ANY"]:
2153            arg = "NULL"
2154        elif not self.is_unambiguous_value():
2155            arg = deref_if_not_null(access)
2156        elif self.type in ["BSTR", "TSTR"]:
2157            arg = tmp_str_or_null(self.value)
2158        elif self.type in ["UINT", "INT", "NINT", "FLOAT", "BOOL"]:
2159            value = val_to_str(self.value)
2160            arg = (f"&({self.val_type_name()}){{{value}}}" if ptr_result else value)
2161        else:
2162            assert False, "Should not come here."
2163
2164        min_val = None
2165        max_val = None
2166
2167        if self.value is None:
2168            if self.type in ["BSTR", "TSTR"]:
2169                min_val = self.min_size
2170                max_val = self.max_size
2171            else:
2172                min_val = self.min_value
2173                max_val = self.max_value
2174        return (func_name, arg)
2175
2176    # Return the function name and arguments to call to encode/decode this element.
2177    def single_func(self, access=None, union_int=None):
2178        if self.single_func_impl_condition():
2179            return (self.xcode_func_name(), deref_if_not_null(access or self.var_access()))
2180        else:
2181            return self.single_func_prim(access or self.val_access(), union_int)
2182
2183    # Return the function name and arguments to call to encode/decode the repeated
2184    # part of this element.
2185    def repeated_single_func(self, ptr_result=False):
2186        if self.repeated_single_func_impl_condition():
2187            return (self.repeated_xcode_func_name(), deref_if_not_null(self.repeated_val_access()))
2188        else:
2189            return self.single_func_prim(self.repeated_val_access(), ptr_result=ptr_result)
2190
2191    def has_backup(self):
2192        return (self.cbor_var_condition() or self.type in ["LIST", "MAP", "UNION"])
2193
2194    def num_backups(self):
2195        total = 0
2196        if self.key:
2197            total += self.key.num_backups()
2198        if self.cbor_var_condition():
2199            total += self.cbor.num_backups()
2200        if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
2201            total += max([child.num_backups() for child in self.value] + [0])
2202        if self.type == "OTHER":
2203            total += self.my_types[self.value].num_backups()
2204        if self.has_backup():
2205            total += 1
2206        return total
2207
    # Return a number indicating how many other elements this element depends on. Used for
    # putting functions and typedefs in the right order.
2210    def depends_on(self):
2211        ret_vals = [1]
2212
2213        if not self.dependsOnCall:
2214            self.dependsOnCall = True
2215            if self.cbor_var_condition():
2216                ret_vals.append(self.cbor.depends_on())
2217            if self.key:
2218                ret_vals.append(self.key.depends_on())
2219            if self.type == "OTHER":
2220                ret_vals.append(1 + self.my_types[self.value].depends_on())
2221            if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
2222                ret_vals.extend(child.depends_on() for child in self.value)
2223            self.dependsOnCall = False
2224
2225        return max(ret_vals)
2226
2227    # Make a string from the list returned by single_func_prim()
2228    def xcode_single_func_prim(self, union_int=None):
2229        return xcode_statement(*self.single_func_prim(self.val_access(), union_int))
2230
2231    # Recursively sum the total minimum and maximum element count for this
2232    # element.
2233    def list_counts(self):
2234        retval = ({
2235            "INT": lambda: (self.min_qty, self.max_qty),
2236            "UINT": lambda: (self.min_qty, self.max_qty),
2237            "NINT": lambda: (self.min_qty, self.max_qty),
2238            "FLOAT": lambda: (self.min_qty, self.max_qty),
2239            "BSTR": lambda: (self.min_qty, self.max_qty),
2240            "TSTR": lambda: (self.min_qty, self.max_qty),
2241            "BOOL": lambda: (self.min_qty, self.max_qty),
2242            "NIL": lambda: (self.min_qty, self.max_qty),
2243            "UNDEF": lambda: (self.min_qty, self.max_qty),
2244            "ANY": lambda: (self.min_qty, self.max_qty),
2245            # Lists are their own element
2246            "LIST": lambda: (self.min_qty, self.max_qty),
2247            # Maps are their own element
2248            "MAP": lambda: (self.min_qty, self.max_qty),
2249            "GROUP": lambda: (self.min_qty * sum((child.list_counts()[0] for child in self.value)),
2250                              self.max_qty * sum((child.list_counts()[1] for child in self.value))),
2251            "UNION": lambda: (self.min_qty * min((child.list_counts()[0] for child in self.value)),
2252                              self.max_qty * max((child.list_counts()[1] for child in self.value))),
2253            "OTHER": lambda: (self.min_qty * self.my_types[self.value].list_counts()[0],
2254                              self.max_qty * self.my_types[self.value].list_counts()[1]),
2255        }[self.type]())
2256        return retval
2257
2258    # Return the full code needed to encode/decode a "LIST" or "MAP" element with children.
2259    def xcode_list(self):
2260        start_func = f"zcbor_{self.type.lower()}_start_{self.mode}"
2261        end_func = f"zcbor_{self.type.lower()}_end_{self.mode}"
2262        end_func_force = f"zcbor_list_map_end_force_{self.mode}"
2263        assert start_func in [
2264            "zcbor_list_start_decode", "zcbor_list_start_encode",
2265            "zcbor_map_start_decode", "zcbor_map_start_encode"]
2266        assert end_func in [
2267            "zcbor_list_end_decode", "zcbor_list_end_encode",
2268            "zcbor_map_end_decode", "zcbor_map_end_encode"]
2269        assert self.type in ["LIST", "MAP"], \
2270            "Expected LIST or MAP type, was %s." % self.type
2271        _, max_counts = zip(
2272            *(child.list_counts() for child in self.value)) if self.value else ((0,), (0,))
2273        count_arg = f', {str(sum(max_counts))}' if self.mode == 'encode' else ''
2274        with_children = "(%s && ((%s) || (%s, false)) && %s)" % (
2275            f"{start_func}(state{count_arg})",
2276            f"{newl_ind}&& ".join(child.full_xcode() for child in self.value),
2277            f"{end_func_force}(state)",
2278            f"{end_func}(state{count_arg})")
2279        without_children = "(%s && %s)" % (
2280            f"{start_func}(state{count_arg})",
2281            f"{end_func}(state{count_arg})")
2282        return with_children if len(self.value) > 0 else without_children
2283
2284    # Return the full code needed to encode/decode a "GROUP" element's children.
2285    def xcode_group(self, union_int=None):
2286        assert self.type in ["GROUP"], "Expected GROUP type."
2287        return "(%s)" % (newl_ind + "&& ").join(
2288            [self.value[0].full_xcode(union_int)]
2289            + [child.full_xcode() for child in self.value[1:]])
2290
2291    # Return the full code needed to encode/decode a "UNION" element's children.
    def xcode_union(self):
        """Return the full code needed to encode/decode a "UNION" element's children.

        Decoding takes one of two forms:
         - If all_children_int_disambiguated() holds, a single integer is decoded into the
           choice variable, followed by one "choice == enum value" comparison per child.
           ZCBOR_ERR_WRONG_VALUE is raised if none of them matches.
         - Otherwise the children are tried in order in an "||" chain between
           zcbor_union_start_code() and zcbor_union_end_code(); the choice variable is set to
           the enum value of the child that succeeds.

        Encoding produces a ternary_if_chain() over the choice variable.
        """
        assert self.type in ["UNION"], "Expected UNION type."
        if self.mode == "decode":
            if self.all_children_int_disambiguated():
                # One "(choice == enum) && (child code)" term per child. The children are
                # rendered with union_int="DROP" since the integer is decoded once, below.
                lines = []
                lines.extend(
                    ["((%s == %s) && (%s))" %
                        (self.choice_var_access(), child.enum_var_name(),
                            child.full_xcode(union_int="DROP"))
                        for child in self.value])
                # NOTE(review): bit_size is not used below; confirm that the call has no needed
                # side effect before removing it.
                bit_size = self.value[0].bit_size()
                func = f"zcbor_uint_{self.mode}" if self.all_children_uint_disambiguated() else \
                       f"zcbor_int_{self.mode}"
                return "((%s) && (%s))" % (
                    f"({func}(state, &{self.choice_var_access()}, "
                    + f"sizeof({self.choice_var_access()})))",
                    "((" + f"{newl_ind}|| ".join(lines)
                         + ") || (zcbor_error(state, ZCBOR_ERR_WRONG_VALUE), false))",)

            # Each term decodes one child and, on success, records which child was chosen.
            child_values = ["(%s && ((%s = %s), true))" %
                            (child.full_xcode(
                                union_int="EXPECT" if child.is_int_disambiguated() else None),
                                self.choice_var_access(), child.enum_var_name())
                            for child in self.value]

            # Reset state for all but the first child.
            for i in range(1, len(child_values)):
                if ((not self.value[i].is_int_disambiguated())
                        and (self.value[i].simple_func_condition()
                             or self.value[i - 1].simple_func_condition())):
                    child_values[i] = f"(zcbor_union_elem_code(state) && {child_values[i]})"

            # int_res holds the result of the chain so that zcbor_union_end_code() is called
            # regardless of whether a child matched.
            return "(%s && (int_res = (%s), %s, int_res))" \
                % ("zcbor_union_start_code(state)",
                   f"{newl_ind}|| ".join(child_values),
                   "zcbor_union_end_code(state)")
        else:
            return ternary_if_chain(
                self.choice_var_access(),
                [child.enum_var_name() for child in self.value],
                [child.full_xcode() for child in self.value])
2333
2334    def xcode_bstr(self):
2335        if self.cbor and not self.cbor.is_entry_type():
2336            access_arg = f', {deref_if_not_null(self.val_access())}' if self.mode == 'decode' \
2337                else ''
2338            res_arg = f', &tmp_str' if self.mode == 'encode' \
2339                else ''
2340            xcode_cbor = "(%s)" % ((newl_ind + "&& ").join(
2341                [f"zcbor_bstr_start_{self.mode}(state{access_arg})",
2342                 f"(int_res = ({self.cbor.full_xcode()}), "
2343                 f"zcbor_bstr_end_{self.mode}(state{res_arg}), int_res)"]))
2344            if self.mode == "decode" or self.is_unambiguous():
2345                return xcode_cbor
2346            else:
2347                return f"({self.val_access()}.value " \
2348                    f"? (memcpy(&tmp_str, &{self.val_access()}, sizeof(tmp_str)), " \
2349                    f"{self.xcode_single_func_prim()}) : ({xcode_cbor}))"
2350        return self.xcode_single_func_prim()
2351
2352    def xcode_tags(self):
2353        return [f"zcbor_tag_{'put' if (self.mode == 'encode') else 'expect'}(state, {tag})"
2354                for tag in self.tags]
2355
2356    # Appends ULL or LL if a value exceeding 32-bits is used
2357    def value_suffix(self, value_str):
2358        if not value_str.isdigit():
2359            return ""
2360        value = int(value_str)
2361        if self.type == "INT" or self.type == "NINT":
2362            if value > INT32_MAX or value <= INT32_MIN:
2363                return "LL"
2364        elif self.type == "UINT":
2365            if value > UINT32_MAX:
2366                return "ULL"
2367
2368        return ""
2369
2370    def range_checks(self, access):
2371        if self.type != "OTHER" and self.value is not None:
2372            return []
2373
2374        range_checks = []
2375
2376        if self.type in ["INT", "UINT", "NINT", "FLOAT", "BOOL"]:
2377            if self.min_value is not None:
2378                range_checks.append(f"({access} >= {val_to_str(self.min_value)}"
2379                                    f"{self.value_suffix(val_to_str(self.min_value))})")
2380            if self.max_value is not None:
2381                range_checks.append(f"({access} <= {val_to_str(self.max_value)}"
2382                                    f"{self.value_suffix(val_to_str(self.max_value))})")
2383            if self.bits:
2384                range_checks.append(
2385                    f"!({access} & ~("
2386                    + ' | '.join([f'(1 << {c.enum_var_name()})'
2387                                 for c in self.my_control_groups[self.bits].value])
2388                    + "))")
2389        elif self.type in ["BSTR", "TSTR"]:
2390            if self.min_size is not None:
2391                range_checks.append(f"({access}.len >= {val_to_str(self.min_size)})")
2392            if self.max_size is not None:
2393                range_checks.append(f"({access}.len <= {val_to_str(self.max_size)})")
2394        elif self.type == "OTHER":
2395            if not self.my_types[self.value].single_func_impl_condition():
2396                range_checks.extend(self.my_types[self.value].range_checks(access))
2397
2398        if range_checks:
2399            range_checks[0] = "((" + range_checks[0]
2400            range_checks[-1] = range_checks[-1] \
2401                + ") || (zcbor_error(state, ZCBOR_ERR_WRONG_RANGE), false))"
2402
2403        return range_checks
2404
2405    # Return the full code needed to encode/decode this element, including children,
2406    # key and cbor, excluding repetitions.
2407    def repeated_xcode(self, union_int=None):
2408        range_checks = self.range_checks(self.val_access())
2409        xcoder = {
2410            "INT": self.xcode_single_func_prim,
2411            "UINT": lambda: self.xcode_single_func_prim(union_int),
2412            "NINT": lambda: self.xcode_single_func_prim(union_int),
2413            "FLOAT": self.xcode_single_func_prim,
2414            "BSTR": self.xcode_bstr,
2415            "TSTR": self.xcode_single_func_prim,
2416            "BOOL": self.xcode_single_func_prim,
2417            "NIL": self.xcode_single_func_prim,
2418            "UNDEF": self.xcode_single_func_prim,
2419            "ANY": self.xcode_single_func_prim,
2420            "LIST": self.xcode_list,
2421            "MAP": self.xcode_list,
2422            "GROUP": lambda: self.xcode_group(union_int),
2423            "UNION": self.xcode_union,
2424            "OTHER": lambda: self.xcode_single_func_prim(union_int),
2425        }[self.type]
2426        xcoders = []
2427        if self.key:
2428            xcoders.append(self.key.full_xcode(union_int))
2429        if self.tags:
2430            xcoders.extend(self.xcode_tags())
2431        if self.mode == "decode":
2432            xcoders.append(xcoder())
2433            xcoders.extend(range_checks)
2434        elif self.type == "BSTR" and self.cbor:
2435            xcoders.append(xcoder())
2436            xcoders.extend(self.range_checks("tmp_str"))
2437        else:
2438            xcoders.extend(range_checks)
2439            xcoders.append(xcoder())
2440
2441        return "(%s)" % ((newl_ind + "&& ").join(xcoders),)
2442
2443    # Code for the size of the repeated part of this element.
2444    def result_len(self):
2445        if self.repeated_type_name() is None or self.is_unambiguous_repeated():
2446            return "0"
2447        else:
2448            return "sizeof(%s)" % self.repeated_type_name()
2449
2450    # Return the full code needed to encode/decode this element, including children,
2451    # key, cbor, and repetitions.
2452    def full_xcode(self, union_int=None):
2453        if self.present_var_condition():
2454            if self.mode == "encode":
2455                func, *arguments = self.repeated_single_func(ptr_result=False)
2456                return f"(!{self.present_var_access()} || {func}({xcode_args(*arguments)}))"
2457            else:
2458                assert self.mode == "decode", \
2459                    f"This code needs self.mode to be 'decode', not {self.mode}."
2460                if not self.repeated_single_func_impl_condition():
2461                    decode_str = self.repeated_xcode(union_int)
2462                    assert "zcbor_" in decode_str, \
2463                        """Must be a direct call to zcbor to guarantee that payload and elem_count"
2464are cleaned up after a failure."""
2465                    return f"({self.present_var_access()} = {self.repeated_xcode(union_int)}, 1)"
2466                func, *arguments = self.repeated_single_func(ptr_result=True)
2467                return (
2468                    f"zcbor_present_decode(&(%s), (zcbor_decoder_t *)%s, %s)" %
2469                    (self.present_var_access(), func, xcode_args(*arguments),))
2470        elif self.count_var_condition():
2471            func, *arguments = self.repeated_single_func(ptr_result=True)
2472            minmax = "_minmax" if self.mode == "encode" else ""
2473            mode = self.mode
2474            return (
2475                f"zcbor_multi_{mode}{minmax}(%s, %s, &%s, (zcbor_{mode}r_t *)%s, %s, %s)" %
2476                (self.min_qty,
2477                 self.max_qty,
2478                 self.count_var_access(),
2479                 func,
2480                 xcode_args(*arguments),
2481                 self.result_len()))
2482        else:
2483            return self.repeated_xcode(union_int)
2484
2485    # Return the body of the encoder/decoder function for this element.
2486    def xcode(self):
2487        return self.full_xcode()
2488
2489    # Recursively return a list of the bodies of the encoder/decoder functions for
2490    # this element and its children + key + cbor.
2491    def xcoders(self):
2492        if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
2493            for child in self.value:
2494                for xcoder in child.xcoders():
2495                    yield xcoder
2496        if self.cbor:
2497            for xcoder in self.cbor.xcoders():
2498                yield xcoder
2499        if self.key:
2500            for xcoder in self.key.xcoders():
2501                yield xcoder
2502        if self.type == "OTHER" and self.value not in self.entry_type_names:
2503            for xcoder in self.my_types[self.value].xcoders():
2504                yield xcoder
2505        if self.repeated_single_func_impl_condition():
2506            yield XcoderTuple(
2507                self.repeated_xcode(), self.repeated_xcode_func_name(), self.repeated_type_name())
2508        if (self.single_func_impl_condition()):
2509            xcode_body = self.xcode()
2510            yield XcoderTuple(xcode_body, self.xcode_func_name(), self.type_name())
2511
2512    def public_xcode_func_sig(self):
2513        type_name = self.type_name()
2514        return f"""
2515int cbor_{self.xcode_func_name()}(
2516		{"const " if self.mode == "decode" else ""}uint8_t *payload, size_t payload_len,
2517		{"" if self.mode == "decode" else "const "}{type_name
2518            if struct_ptr_name(self.mode) in self.full_xcode() else "void"} *{
2519            struct_ptr_name(self.mode)},
2520		{"size_t *payload_len_out"})"""
2521
2522
class CodeRenderer():
    """Render the generated C source, header, type and CMake files for a set of entry types.

    entry_types maps each mode ("decode"/"encode") to the list of entry type objects for that
    mode. On construction, the functions and typedefs needed for each mode are collected; the
    render*() methods then turn them into file contents.
    """

    def __init__(self, entry_types, modes, print_time, default_max_qty, git_sha='', file_header=''):
        """Collect functions and typedefs for each mode, and prepare the file header text.

        print_time controls whether a timestamp is added to the file header. git_sha, if
        given, is appended to the version. file_header is optional text placed at the top of
        the generated header comment.
        """
        self.entry_types = entry_types
        self.print_time = print_time
        self.default_max_qty = default_max_qty

        self.sorted_types = dict()
        self.functions = dict()
        self.type_defs = dict()

        # Sort type definitions so the typedefs will come in the correct order in the header file
        # and the function in the correct order in the c file.
        for mode in modes:
            self.sorted_types[mode] = list(sorted(
                self.entry_types[mode], key=lambda _type: _type.depends_on(), reverse=False))

            # used_funcs() reads the list produced by unique_funcs() from self.functions[mode].
            self.functions[mode] = self.unique_funcs(mode)
            self.functions[mode] = self.used_funcs(mode)
            self.type_defs[mode] = self.unique_types(mode)

        self.version = __version__

        if git_sha:
            self.version += f'-{git_sha}'

        self.file_header = file_header.strip() + "\n\n" if file_header.strip() else ""
        self.file_header += f"""Generated using zcbor version {self.version}
https://github.com/NordicSemiconductor/zcbor{'''
at: ''' + datetime.now().strftime('%Y-%m-%d %H:%M:%S') if self.print_time else ''}
Generated with a --default-max-qty of {self.default_max_qty}"""

    def header_guard(self, file_name):
        """Return the include guard name for file_name: its upper-cased base name with "." and
        "-" replaced by "_", followed by "__"."""
        return path.basename(file_name).replace(".", "_").replace("-", "_").upper() + "__"

    def unique_types(self, mode):
        """Return a list of typedefs for all defined types, with duplicate typedefs removed.

        Two typedefs with the same name must have identical contents, otherwise an
        AssertionError is raised.
        """
        type_names = {}
        out_types = []
        for mtype in self.sorted_types[mode]:
            for type_def in mtype.type_def():
                type_name = type_def[1]
                if type_name not in type_names:
                    type_names[type_name] = type_def[0]
                    out_types.append(type_def)
                else:
                    assert (''.join(type_names[type_name]) == ''.join(type_def[0])), f"""
Two elements share the type name {type_name}, but their implementations are not identical.
Please change one or both names. They are
{linesep.join(type_names[type_name])}
and
{linesep.join(type_def[0])}"""
        return out_types

    def unique_funcs(self, mode):
        """Return a list of encoder/decoder functions for all defined types, with duplicate
        functions removed.

        Two functions with the same name must have identical bodies, otherwise an
        AssertionError is raised.
        """
        func_names = {}
        out_types = []
        for mtype in self.sorted_types[mode]:
            xcoders = list(mtype.xcoders())
            for funcType in xcoders:
                func_xcode = funcType[0]
                func_name = funcType[1]
                if func_name not in func_names:
                    func_names[func_name] = funcType
                    out_types.append(funcType)
                else:
                    assert func_names[func_name][0] == func_xcode, \
                        ("Two elements share the function name %s, but their implementations are "
                            + "not identical. Please change one or both names.\n\n%s\n\n%s") % \
                        (func_name, func_names[func_name][0], func_xcode)

        return out_types

    def used_funcs(self, mode):
        """Return a list of encoder/decoder functions for all defined types, with unused
        functions removed.

        Starting from the entry functions, self.functions[mode] is walked in reverse order and
        a function is kept if its name occurs in the code of the functions kept so far.
        """
        mod_entry_types = [
            XcoderTuple(
                func_type.xcode(),
                func_type.xcode_func_name(),
                func_type.type_name()) for func_type in self.entry_types[mode]]
        out_types = [func_type for func_type in mod_entry_types]
        full_code = "".join([func_type[0] for func_type in mod_entry_types])
        for func_type in reversed(self.functions[mode]):
            func_name = func_type[1]
            # The trailing \W keeps a name from matching as the prefix of a longer name.
            if func_type not in mod_entry_types and getrp(r"%s\W" % func_name).search(full_code):
                full_code += func_type[0]
                out_types.append(func_type)
        return list(reversed(out_types))

    def render_forward_declaration(self, xcoder, mode):
        """Render the forward declaration (signature only) of a single xcode function.

        The struct pointer is typed as void when its name does not occur in the function body.
        """
        return f"""
static bool {xcoder.func_name}(zcbor_state_t *state, {"" if mode == "decode" else "const "}{
            xcoder.type_name
            if struct_ptr_name(mode) in xcoder.body else "void"} *{struct_ptr_name(mode)});
            """.strip()

    def render_function(self, xcoder, mode):
        """Render a single xcode function with signature and body.

        The tmp_str and int_res locals are only declared when the body mentions them.
        """
        body = xcoder.body
        return f"""
static bool {xcoder.func_name}(
\t\tzcbor_state_t *state, {"" if mode == "decode" else "const "}{
            xcoder.type_name
            if struct_ptr_name(mode) in body else "void"} *{struct_ptr_name(mode)})
{{
\tzcbor_log("%s\\r\\n", __func__);
\t{"struct zcbor_string tmp_str;" if "tmp_str" in body else ""}
\t{"bool int_res;" if "int_res" in body else ""}

\tbool tmp_result = ({ body });

\tif (!tmp_result) {{
\t\tzcbor_trace_file(state);
\t\tzcbor_log("%s error: %s\\r\\n", __func__, zcbor_error_str(zcbor_peek_error(state)));
\t}} else {{
\t\tzcbor_log("%s success\\r\\n", __func__);
\t}}

\treturn tmp_result;
}}""".replace("\t\n", "")  # Remove the lines left empty by unused declarations.

    def render_entry_function(self, xcoder, mode):
        """Render a single entry function (API function) with signature and body."""
        func_name, func_arg = (xcoder.xcode_func_name(), struct_ptr_name(mode))
        return f"""
{xcoder.public_xcode_func_sig()}
{{
\tzcbor_state_t states[{xcoder.num_backups() + 2}];

\treturn zcbor_entry_function(payload, payload_len, (void *){func_arg}, payload_len_out, states,
\t\t(zcbor_decoder_t *){func_name}, sizeof(states) / sizeof(zcbor_state_t), {
            xcoder.list_counts()[1]});
}}"""

    def render_file_header(self, line_prefix):
        """Return the file header with line_prefix at the start of every line, and without
        trailing spaces on otherwise empty lines."""
        lp = line_prefix
        return (f"\n{lp} " + self.file_header.replace("\n", f"\n{lp} ")).replace(" \n", "\n")

    def render_c_file(self, header_file_name, mode):
        """Render the entire generated C file contents."""
        return f"""/*{self.render_file_header(" *")}
 */

#include <stdint.h>
#include <stdbool.h>
#include <stddef.h>
#include <string.h>
#include "zcbor_{mode}.h"
#include "{header_file_name}"
#include "zcbor_print.h"

#if DEFAULT_MAX_QTY != {self.default_max_qty}
#error "The type file was generated with a different default_max_qty than this file"
#endif

{linesep.join([self.render_forward_declaration(xcoder, mode) for xcoder in self.functions[mode]])}

{linesep.join([self.render_function(xcoder, mode) for xcoder in self.functions[mode]])}

{linesep.join([self.render_entry_function(xcoder, mode) for xcoder in self.entry_types[mode]])}
"""

    def render_h_file(self, type_def_file, header_guard, mode):
        """Render the entire generated header file contents."""
        return \
            f"""/*{self.render_file_header(" *")}
 */

#ifndef {header_guard}
#define {header_guard}

#include <stdint.h>
#include <stdbool.h>
#include <stddef.h>
#include <string.h>
#include "{type_def_file}"

#ifdef __cplusplus
extern "C" {{
#endif

#if DEFAULT_MAX_QTY != {self.default_max_qty}
#error "The type file was generated with a different default_max_qty than this file"
#endif

{(linesep+linesep).join(
    [f"{xcoder.public_xcode_func_sig()};" for xcoder in self.entry_types[mode]])}


#ifdef __cplusplus
}}
#endif

#endif /* {header_guard} */
"""

    def render_type_file(self, header_guard, mode):
        """Render the entire generated type definition header contents."""
        body = (
            linesep + linesep).join(
                [f"{typedef[1]} {{{linesep}{linesep.join(typedef[0][1:])};"
                    for typedef in self.type_defs[mode]])
        return \
            f"""/*{self.render_file_header(" *")}
 */

#ifndef {header_guard}
#define {header_guard}

#include <stdint.h>
#include <stdbool.h>
#include <stddef.h>
{'#include <zcbor_common.h>' if "struct zcbor_string" in body else ""}

#ifdef __cplusplus
extern "C" {{
#endif

/** Which value for --default-max-qty this file was created with.
 *
 *  The define is used in the other generated file to do a build-time
 *  compatibility check.
 *
 *  See `zcbor --help` for more information about --default-max-qty
 */
#define DEFAULT_MAX_QTY {self.default_max_qty}

{body}

#ifdef __cplusplus
}}
#endif

#endif /* {header_guard} */
"""

    def render_cmake_file(self, target_name, h_files, c_files, type_file,
                          output_c_dir, output_h_dir, cmake_dir):
        """Render a CMake file that defines a library target with the generated sources.

        All paths are written relative to cmake_dir, prefixed with ${CMAKE_CURRENT_LIST_DIR}.
        """
        include_dirs = sorted(set(((Path(output_h_dir)),
                                  (Path(type_file.name).parent),
                                  *((Path(h.name).parent) for h in h_files.values()))))

        def relativify(p):
            return PurePosixPath(
                Path("${CMAKE_CURRENT_LIST_DIR}") / path.relpath(Path(p), cmake_dir))
        return \
            f"""\
#{self.render_file_header("#")}
#

add_library({target_name})
target_sources({target_name} PRIVATE
    {relativify(Path(output_c_dir, "zcbor_decode.c"))}
    {relativify(Path(output_c_dir, "zcbor_encode.c"))}
    {relativify(Path(output_c_dir, "zcbor_common.c"))}
    {(linesep + "    ").join(((str(relativify(c.name))) for c in c_files.values()))}
    )
target_include_directories({target_name} PUBLIC
    {(linesep + "    ").join(((str(relativify(f)) for f in include_dirs)))}
    )
"""

    def render(self, modes, h_files, c_files, type_file, include_prefix, cmake_file=None,
               output_c_dir=None, output_h_dir=None):
        """Render and write all generated files.

        h_files and c_files map each mode to an open, writable file object; type_file and the
        optional cmake_file are open, writable file objects as well. include_prefix is
        prepended to the file names used in the generated #include directives.
        """
        modes = list(modes)
        for mode in modes:
            h_name = Path(include_prefix, Path(h_files[mode].name).name)

            # Create and populate the generated c and h file.
            # The directory is derived from the file's own path, so that an absolute output
            # path is not re-rooted under the current working directory.
            Path(c_files[mode].name).absolute().parent.mkdir(parents=True, exist_ok=True)

            type_def_name = Path(include_prefix, Path(type_file.name).name)

            print("Writing to " + c_files[mode].name)
            c_files[mode].write(self.render_c_file(h_name, mode))

            print("Writing to " + h_files[mode].name)
            h_files[mode].write(self.render_h_file(
                type_def_name,
                self.header_guard(h_files[mode].name), mode))

        # There is a single type file for all modes; it is rendered from the typedefs of the
        # last mode.
        type_mode = modes[-1]
        print("Writing to " + type_file.name)
        type_file.write(self.render_type_file(self.header_guard(type_file.name), type_mode))

        if cmake_file:
            print("Writing to " + cmake_file.name)
            cmake_file.write(self.render_cmake_file(
                Path(cmake_file.name).stem, h_files, c_files, type_file,
                output_c_dir, output_h_dir, Path(cmake_file.name).absolute().parent))
2813
2814
def int_or_str(arg):
    """Argparse type: return arg as an int if possible, else as a plain word.

    A non-integer argument is accepted unchanged if it consists only of word characters.
    Raises ArgumentTypeError otherwise.
    """
    try:
        return int(arg)
    except ValueError:
        pass

    if getrp(r"\A\w+\Z").match(arg) is None:
        raise ArgumentTypeError(
            "Argument must be an integer or a string with only letters, numbers, or '_'.")
    return arg
2824
2825
def parse_args():
    """Build the zcbor command line parser and parse ``sys.argv``.

    Three subcommands are defined: ``code``, ``validate`` and ``convert``. Each
    subcommand stores its handler function in the ``process`` attribute of the
    returned namespace.

    Unless ``--no-prelude`` is given, the prelude file at ``PRELUDE_PATH`` is
    opened and appended to ``args.cddl``.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits (via ``parser.error``) when no subcommand is given, when the ``code``
    subcommand is used without ``--decode``/``--encode``, or when the ``code``
    subcommand lacks both a complete ``--output-c``/``--output-h`` pair and
    ``--output-cmake``.
    """
    # Options shared by all subcommands.
    parent_parser = ArgumentParser(add_help=False)

    parent_parser.add_argument(
        "-c", "--cddl", required=True, type=FileType('r'), action="append",
        help="""Path to one or more input CDDL file(s). Passing multiple files is equivalent to
concatenating them.""")
    parent_parser.add_argument(
        "--no-prelude", required=False, action="store_true", default=False,
        help=f"""Exclude the standard CDDL prelude from the build. The prelude can be viewed at
{PRELUDE_PATH.relative_to(PACKAGE_PATH)} in the repo, or together with the script.""")
    parent_parser.add_argument(
        "-v", "--verbose", required=False, action="store_true", default=False,
        help="Print more information while parsing CDDL and generating code.")

    parser = ArgumentParser(
        description='''Parse a CDDL file and validate/convert between YAML, JSON, and CBOR.
Can also generate C code for validation/encoding/decoding of CBOR.''')

    parser.add_argument(
        "--version", action="version", version=f"zcbor {__version__}")

    subparsers = parser.add_subparsers()

    # 'code' subcommand.
    code_parser = subparsers.add_parser(
        "code", description='''Parse a CDDL file and produce C code that validates and xcodes CBOR.
The output from this script is a C file and a header file. The header file
contains typedefs for all the types specified in the cddl input file, as well
as declarations to xcode functions for the types designated as entry types when
running the script. The c file contains all the code for decoding and validating
the types in the CDDL input file. All types are validated as they are xcoded.

Where a `bstr .cbor <Type>` is specified in the CDDL, AND the Type is an entry
type, the xcoder will not xcode the string, only provide a pointer into the
payload buffer. This is useful to reduce the size of typedefs, or to break up
decoding. Using this mechanism is necessary when the CDDL contains self-
referencing types, since the C type cannot be self referencing.

This script requires 'regex' for lookaround functionality not present in 're'.''',
        formatter_class=RawDescriptionHelpFormatter,
        parents=[parent_parser])

    code_parser.add_argument(
        "--default-max-qty", "--dq", required=False, type=int_or_str, default=3,
        help="""Default maximum number of repetitions when no maximum
is specified. This is needed to construct complete C types.

The default_max_qty can usually be set to a text symbol if desired,
to allow it to be configurable when building the code. This is not always
possible, as sometimes the value is needed for internal computations.
If so, the script will raise an exception.""")
    code_parser.add_argument(
        "--output-c", "--oc", required=False, type=str,
        help="""Path to output C file. If both --decode and --encode are specified, _decode and
_encode will be appended to the filename when creating the two files. If not
specified, the path and name will be based on the --output-cmake file. A 'src'
directory will be created next to the cmake file, and the C file will be
placed there with the same name (except the file extension) as the cmake file.""")
    code_parser.add_argument(
        "--output-h", "--oh", required=False, type=str,
        help="""Path to output header file. If both --decode and --encode are specified, _decode and
_encode will be appended to the filename when creating the two files. If not
specified, the path and name will be based on the --output-cmake file. An 'include'
directory will be created next to the cmake file, and the header file will be
placed there with the same name (except the file extension) as the cmake file.""")
    code_parser.add_argument(
        "--output-h-types", "--oht", required=False, type=str,
        help="""Path to output header file with typedefs (shared between decode and encode).
If not specified, the path and name will be taken from the output header file
(--output-h), with '_types' added to the file name.""")
    code_parser.add_argument(
        "--copy-sources", required=False, action="store_true", default=False,
        help="""Copy the non-generated source files (zcbor_*.c/h) into the same directories as the
generated files.""")
    code_parser.add_argument(
        "--output-cmake", required=False, type=str,
        help="""Path to output CMake file. The filename of the CMake file without '.cmake' is used
as the name of the CMake target in the file.
The CMake file defines a CMake target with the zcbor source files and the
generated file as sources, and the zcbor header files' and generated header
files' folders as include_directories.
Add it to your project via include() in your CMakeLists.txt file, and link the
target to your program.
This option works with or without the --copy-sources option.""")
    code_parser.add_argument(
        "-t", "--entry-types", required=True, type=str, nargs="+",
        help="Names of the types which should have their xcode functions exposed.")
    code_parser.add_argument(
        "-d", "--decode", required=False, action="store_true", default=False,
        help="Generate decoding code. Either --decode or --encode or both must be specified.")
    code_parser.add_argument(
        "-e", "--encode", required=False, action="store_true", default=False,
        help="Generate encoding code. Either --decode or --encode or both must be specified.")
    code_parser.add_argument(
        "--time-header", required=False, action="store_true", default=False,
        help="Put the current time in a comment in the generated files.")
    code_parser.add_argument(
        "--git-sha-header", required=False, action="store_true", default=False,
        help="Put the current git sha of zcbor in a comment in the generated files.")
    code_parser.add_argument(
        "-b", "--default-bit-size", required=False, type=int, default=32, choices=[32, 64],
        help="""Default bit size of integers in code. When integers have no explicit bounds,
assume they have this bit width. Should follow the bit width of the architecture
the code will be running on.""")
    code_parser.add_argument(
        "--include-prefix", default="",
        help="""When #include'ing generated files, add this path prefix to the filename.""")
    code_parser.add_argument(
        "-s", "--short-names", required=False, action="store_true", default=False,
        help="""Attempt to make most generated struct member names shorter. This might make some
names identical which will cause a compile error. If so, tweak the CDDL labels
or layout, or disable this option. This might also make enum names different
from the corresponding union members.""")
    code_parser.add_argument(
        "--file-header", required=False, type=str, default="",
        help="Header to be included in the comment at the top of generated C files, e.g. copyright."
    )
    code_parser.set_defaults(process=process_code)

    # Options shared by the 'validate' and 'convert' subcommands.
    validate_parent_parser = ArgumentParser(add_help=False)
    validate_parent_parser.add_argument(
        "-i", "--input", required=True, type=str,
        help='''Input data file. The option --input-as specifies how to interpret the contents.
Use "-" to indicate stdin.''')
    validate_parent_parser.add_argument(
        "--input-as", required=False, choices=["yaml", "json", "cbor", "cborhex"],
        help='''Which format to interpret the input file as.
If omitted, the format is inferred from the file name.
.yaml, .yml => YAML, .json => JSON, .cborhex => CBOR as hex string, everything else => CBOR''')
    validate_parent_parser.add_argument(
        "-t", "--entry-type", required=True, type=str,
        help='''Name of the type (from the CDDL) to interpret the data as.''')
    validate_parent_parser.add_argument(
        "--default-max-qty", "--dq", required=False, type=int, default=0xFFFFFFFF,
        help="""Default maximum number of repetitions when no maximum is specified.
It is only relevant when handling data that will be decoded by generated code.
If omitted, a large number will be used.""")
    validate_parent_parser.add_argument(
        "--yaml-compatibility", required=False, action="store_true", default=False,
        help='''Whether to convert CBOR-only values to YAML-compatible ones
(when converting from CBOR), or vice versa (when converting to CBOR).

When this is enabled, all CBOR data is guaranteed to convert into YAML/JSON.
JSON and YAML do not support all data types that CBOR/CDDL supports.
bytestrings (BSTR), tags, undefined, and maps with non-text keys need
special handling. See the zcbor README for more information.''')

    # 'validate' subcommand.
    validate_parser = subparsers.add_parser(
        "validate", description='''Read CBOR, YAML, or JSON data from file or stdin and validate
it against a CDDL schema file.
        ''',
        parents=[parent_parser, validate_parent_parser])

    validate_parser.set_defaults(process=process_validate)

    # 'convert' subcommand.
    convert_parser = subparsers.add_parser(
        "convert", description='''Parse a CDDL file and validate/convert between CBOR and YAML/JSON.
The script decodes the CBOR/YAML/JSON data from a file or stdin
and verifies that it conforms to the CDDL description.
The script fails if the data does not conform.
'zcbor validate' can be used if only validate is needed.''',
        parents=[parent_parser, validate_parent_parser])

    convert_parser.add_argument(
        "-o", "--output", required=True, type=str,
        help='''Output data file. The option --output-as specifies how to interpret the contents.
 Use "-" to indicate stdout.''')
    convert_parser.add_argument(
        "--output-as", required=False, choices=["yaml", "json", "cbor", "cborhex", "c_code"],
        help='''Which format to interpret the output file as.
If omitted, the format is inferred from the file name.
.yaml, .yml => YAML, .json => JSON, .c, .h => C code,
.cborhex => CBOR as hex string, everything else => CBOR''')
    convert_parser.add_argument(
        "--c-code-var-name", required=False, type=str,
        help='''Only relevant together with '--output-as c_code' or .c files.''')
    convert_parser.add_argument(
        "--c-code-columns", required=False, type=int, default=0,
        help='''Only relevant together with '--output-as c_code' or .c files.
The number of bytes per line in the variable instantiation. If omitted, the
entire declaration is a single line.''')
    convert_parser.set_defaults(process=process_convert)

    args = parser.parse_args()

    # Subcommands are optional in argparse by default. Without one, the
    # namespace has none of the attributes used below (nor 'process'), so
    # report a usage error instead of failing with an AttributeError.
    if not hasattr(args, "process"):
        parser.error("Please specify a subcommand: code, validate, or convert.")

    if not args.no_prelude:
        # The prelude is appended last, like an extra --cddl file. It is left
        # open because the handlers read all of args.cddl later.
        args.cddl.append(open(PRELUDE_PATH, 'r', encoding="utf-8"))

    # 'decode' only exists in the namespace of the 'code' subcommand.
    if hasattr(args, "decode") and not args.decode and not args.encode:
        parser.error("Please specify at least one of --decode or --encode.")

    # 'output_c' only exists in the namespace of the 'code' subcommand.
    if hasattr(args, "output_c"):
        have_both_outputs = args.output_c and args.output_h
        if not have_both_outputs and not args.output_cmake:
            parser.error(
                "Please specify both --output-c and --output-h "
                "unless --output-cmake is specified.")

    return args
3025
3026
def process_code(args):
    """Handle the 'code' subcommand: parse the CDDL and write the generated files.

    The CDDL files in ``args.cddl`` are read and joined, parsed once per
    requested mode ("decode" and/or "encode") with ``CodeGenerator.from_cddl``,
    and rendered with ``CodeRenderer.render`` into the C file(s), header
    file(s), the shared types header and, optionally, a CMake file.

    All output files opened here are closed before the function returns, also
    when rendering raises.

    Args:
        args: The namespace produced by ``parse_args`` for the 'code'
            subcommand.
    """
    # Imported locally so the block does not rely on a file-level import change.
    from contextlib import ExitStack

    modes = [mode for mode in ("decode", "encode") if getattr(args, mode)]

    print("Parsing files: " + ", ".join((c.name for c in args.cddl)))

    cddl_contents = linesep.join((c.read() for c in args.cddl))

    cddl_res = dict()
    for mode in modes:
        cddl_res[mode] = CodeGenerator.from_cddl(
            mode, cddl_contents, args.default_max_qty, mode, args.entry_types,
            args.default_bit_size, short_names=args.short_names)

    # Parsing is done, pretty print the result.
    verbose_print(args.verbose, "Parsed CDDL types:")
    for mode in modes:
        verbose_pprint(args.verbose, cddl_res[mode].my_types)

    git_sha = ''
    if args.git_sha_header:
        if "zcbor.py" in sys.argv[0]:
            # Invoked as the script itself: ask git for the short HEAD sha.
            git_args = ['git', 'rev-parse', '--verify', '--short', 'HEAD']
            git_sha = Popen(
                git_args, cwd=PACKAGE_PATH, stdout=PIPE).communicate()[0].decode('utf-8').strip()
        else:
            git_sha = __version__

    def add_mode_to_fname(filename, mode):
        """Return ``filename`` with ``_<mode>`` inserted before the suffix."""
        name = Path(filename).stem + "_" + mode + Path(filename).suffix
        return Path(filename).with_name(name)

    with ExitStack() as open_files:

        def create_and_open(file_path, mode='w'):
            """Create the parent directories of ``file_path`` and open it.

            The file is registered with the surrounding ExitStack so that it
            is closed when process_code() finishes.
            """
            Path(file_path).absolute().parent.mkdir(parents=True, exist_ok=True)
            return open_files.enter_context(Path(file_path).open(mode))

        if args.output_cmake:
            cmake_dir = Path(args.output_cmake).parent
            output_cmake = create_and_open(args.output_cmake)
            filenames = Path(args.output_cmake).parts[-1].replace(".cmake", "")
        else:
            # cmake_dir/filenames are only used as fallbacks below. parse_args()
            # requires --output-c and --output-h when --output-cmake is absent.
            output_cmake = None

        output_c = dict()
        output_h = dict()
        # With a single mode, the given file names are used verbatim. With both
        # modes, "_decode"/"_encode" is added to the names.
        out_c = args.output_c if (len(modes) == 1 and args.output_c) else None
        out_h = args.output_h if (len(modes) == 1 and args.output_h) else None
        for mode in modes:
            output_c[mode] = create_and_open(
                out_c or add_mode_to_fname(
                    args.output_c or Path(cmake_dir, 'src', f'{filenames}.c'), mode))
            output_h[mode] = create_and_open(
                out_h or add_mode_to_fname(
                    args.output_h or Path(cmake_dir, 'include', f'{filenames}.h'), mode))

        out_c_parent = Path(output_c[modes[0]].name).parent
        out_h_parent = Path(output_h[modes[0]].name).parent

        output_h_types = create_and_open(
            args.output_h_types
            or (args.output_h
                and Path(args.output_h).with_name(Path(args.output_h).stem + "_types.h"))
            or Path(cmake_dir, 'include', filenames + '_types.h'))

        renderer = CodeRenderer(
            entry_types={mode: [cddl_res[mode].my_types[entry] for entry in args.entry_types]
                         for mode in modes},
            modes=modes, print_time=args.time_header,
            default_max_qty=args.default_max_qty, git_sha=git_sha,
            file_header=args.file_header)

        c_code_dir = C_SRC_PATH
        h_code_dir = C_INCLUDE_PATH

        if args.copy_sources:
            # Copy the non-generated zcbor sources next to the generated files
            # and pass those directories to the renderer instead.
            for c_name in ("zcbor_decode.c", "zcbor_encode.c", "zcbor_common.c"):
                copyfile(Path(c_code_dir, c_name), Path(out_c_parent, c_name))
            for h_name in ("zcbor_decode.h", "zcbor_encode.h", "zcbor_common.h",
                           "zcbor_tags.h", "zcbor_print.h"):
                copyfile(Path(h_code_dir, h_name), Path(out_h_parent, h_name))
            c_code_dir = out_c_parent
            h_code_dir = out_h_parent

        renderer.render(modes, output_h, output_c, output_h_types, args.include_prefix,
                        output_cmake, c_code_dir, h_code_dir)
3119
3120
def parse_cddl(args):
    """Parse the CDDL inputs and return the requested entry type.

    Args:
        args: Namespace providing ``cddl`` (readable file objects),
            ``default_max_qty`` and ``entry_type``.

    Returns:
        The ``my_types[args.entry_type]`` entry of the object returned by
        ``DataTranslator.from_cddl``.
    """
    cddl_texts = [cddl_file.read() for cddl_file in args.cddl]
    translator = DataTranslator.from_cddl(linesep.join(cddl_texts), args.default_max_qty)
    return translator.my_types[args.entry_type]
3125
3126
def read_data(args, cddl):
    """Read the input data, check it against ``cddl`` and return it as CBOR bytes.

    The format is ``args.input_as`` if given, otherwise the extension of
    ``args.input``. YAML and JSON text is passed to ``cddl.from_yaml`` /
    ``cddl.from_json``. "cborhex" text is decoded from hex, and anything else
    is read as raw bytes; both are then passed to ``cddl.validate_str``.

    An input of "-" reads from stdin. Files opened here are closed before
    returning; stdin is left open.

    Args:
        args: Namespace providing ``input``, ``input_as`` and
            ``yaml_compatibility``.
        cddl: Object providing ``from_yaml``, ``from_json`` and ``validate_str``.

    Returns:
        The value returned by ``cddl.from_yaml``/``cddl.from_json``, or the
        bytes read for the "cborhex" and raw CBOR cases.
    """
    _, in_file_ext = path.splitext(args.input)
    in_file_format = args.input_as or in_file_ext.strip(".")

    def read_input(binary=False):
        """Return the whole input, as bytes if ``binary`` else as text."""
        if args.input == "-":
            return (sys.stdin.buffer if binary else sys.stdin).read()
        with open(args.input, "rb" if binary else "r") as in_file:
            return in_file.read()

    if in_file_format in ["yaml", "yml"]:
        cbor_str = cddl.from_yaml(read_input(), yaml_compat=args.yaml_compatibility)
    elif in_file_format == "json":
        cbor_str = cddl.from_json(read_input(), yaml_compat=args.yaml_compatibility)
    elif in_file_format == "cborhex":
        cbor_str = bytes.fromhex(read_input().replace("\n", ""))
        cddl.validate_str(cbor_str)
    else:
        # Any other format name or extension is treated as raw CBOR.
        cbor_str = read_input(binary=True)
        cddl.validate_str(cbor_str)

    return cbor_str
3146
3147
def write_data(args, cddl, cbor_str):
    """Convert ``cbor_str`` to the requested output format and write it out.

    The format is ``args.output_as`` if given, otherwise the extension of
    ``args.output``: YAML, JSON, C code, "cborhex" (hex text with a newline
    after every 64 characters), or raw CBOR bytes for anything else.

    The output is converted completely before the destination is opened, so a
    failing conversion or a missing ``--c-code-var-name`` does not truncate an
    existing output file. An output of "-" writes to stdout. Files opened here
    are closed before returning.

    Args:
        args: Namespace providing ``output``, ``output_as``,
            ``yaml_compatibility``, ``c_code_var_name`` and ``c_code_columns``.
        cddl: Object providing ``str_to_yaml``, ``str_to_json`` and
            ``str_to_c_code``.
        cbor_str: The CBOR payload as bytes.

    Raises:
        AssertionError: If C code output is requested without a variable name.
    """
    _, out_file_ext = path.splitext(args.output)
    out_file_format = args.output_as or out_file_ext.strip(".")

    binary = False
    if out_file_format in ["yaml", "yml"]:
        payload = cddl.str_to_yaml(cbor_str, yaml_compat=args.yaml_compatibility)
    elif out_file_format == "json":
        payload = cddl.str_to_json(cbor_str, yaml_compat=args.yaml_compatibility)
    elif out_file_format in ["c", "h", "c_code"]:
        # Raised explicitly (rather than via an assert statement) so the check
        # also runs under "python -O". The exception type is kept as before.
        if args.c_code_var_name is None:
            raise AssertionError("Must specify --c-code-var-name when outputting c code.")
        payload = cddl.str_to_c_code(cbor_str, args.c_code_var_name, args.c_code_columns)
    elif out_file_format == "cborhex":
        hex_str = cbor_str.hex()
        # Add newlines every 64 chars
        payload = "".join(hex_str[i:i + 64] + "\n" for i in range(0, len(hex_str), 64))
    else:
        # Any other format name or extension is written as raw CBOR.
        payload = cbor_str
        binary = True

    if args.output == "-":
        (sys.stdout.buffer if binary else sys.stdout).write(payload)
    else:
        with open(args.output, "wb" if binary else "w") as out_file:
            out_file.write(payload)
3168
3169
def process_validate(args):
    """Handle the 'validate' subcommand.

    Parses the CDDL and reads the input data with ``read_data``, which checks
    the data against the entry type. The decoded data is discarded.
    """
    entry_type = parse_cddl(args)
    read_data(args, entry_type)
3173
3174
def process_convert(args):
    """Handle the 'convert' subcommand.

    Parses the CDDL, reads the input data with ``read_data`` and writes it
    back out in the requested format with ``write_data``.
    """
    entry_type = parse_cddl(args)
    write_data(args, entry_type, read_data(args, entry_type))
3179
3180
def main():
    """Command line entry point: parse the arguments and run the chosen handler."""
    parsed_args = parse_args()
    # Each subcommand stores its handler in the 'process' attribute.
    parsed_args.process(parsed_args)
3184
3185
# Run the command line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
3188