1#!/usr/bin/env python3
2#
3# Copyright (c) 2020 Nordic Semiconductor ASA
4#
5# SPDX-License-Identifier: Apache-2.0
6#
7
8from regex import match, search, sub, findall, S, M, fullmatch
9from pprint import pformat, pprint
10from os import path, linesep, makedirs
11from collections import defaultdict, namedtuple
12from typing import NamedTuple
13from argparse import ArgumentParser, ArgumentTypeError, RawDescriptionHelpFormatter, FileType
14from datetime import datetime
15from copy import copy
16from itertools import tee
17from cbor2 import (loads, dumps, CBORTag, load, CBORDecodeValueError, CBORDecodeEOF, undefined,
18                   CBORSimpleValue)
19from yaml import safe_load as yaml_load, dump as yaml_dump
20from json import loads as json_load, dumps as json_dump
21from io import BytesIO
22from subprocess import Popen, PIPE
23from pathlib import Path, PurePath
24from shutil import copyfile
25import sys
26from site import USER_BASE
27from textwrap import wrap, indent
28
# One indentation level for generated C code, and a newline followed by that indentation.
indentation = "\t"
newl_ind = f"\n{indentation}"
31
# Directory containing this script; it holds the VERSION file and the cddl/ directory.
P_SCRIPT = Path(__file__).absolute().parent
# One level above P_SCRIPT (the repository root when running from a checkout).
P_REPO_ROOT = Path(__file__).absolute().parents[1]
VERSION_path = P_SCRIPT / "VERSION"
PRELUDE_path = P_SCRIPT / "cddl" / "prelude.cddl"

# Package version, read (whitespace-stripped) from the VERSION file at import time.
__version__ = VERSION_path.read_text().strip()
38
# Fixed-width integer limits (named after the C <stdint.h> macros).
UINT8_MAX = (1 << 8) - 1
UINT16_MAX = (1 << 16) - 1
UINT32_MAX = (1 << 32) - 1
UINT64_MAX = (1 << 64) - 1

INT8_MAX = (1 << 7) - 1
INT16_MAX = (1 << 15) - 1
INT32_MAX = (1 << 31) - 1
INT64_MAX = (1 << 63) - 1

INT8_MIN = -(1 << 7)
INT16_MIN = -(1 << 15)
INT32_MIN = -(1 << 31)
INT64_MIN = -(1 << 63)
53
54
def is_relative_to(path1, path2):
    """Return True if path1 is path2 or lies below it, else False.

    Mirrors PurePath.is_relative_to (Python 3.9+): path1.relative_to(path2) is attempted and a
    ValueError is taken to mean "not relative".
    """
    try:
        path1.relative_to(path2)
    except ValueError:
        return False
    else:
        return True
61
62
# The root of the non-generated C code (<c_code_root>/src and <c_code_root>/include).
# The location is guessed from where this script lives. The branches are tried in order and the
# first match wins, so the "<sys.prefix>/local" case must stay ahead of the plain "<sys.prefix>"
# case.
if Path(__file__).name in sys.argv[0]:
    # This file's name appears in argv[0]: assume the script is run directly in the repo.
    c_code_root = P_REPO_ROOT
elif any((match(r"zcbor-.*\.egg", p) for p in P_SCRIPT.parts)):
    # A path component starts with "zcbor-" and contains ".egg": installed via setup.py install.
    c_code_root = Path(P_SCRIPT.parent, "lib", "zcbor")
elif is_relative_to(P_SCRIPT, (Path(sys.prefix, "local"))):
    # Script is under <sys.prefix>/local: installed via pip as root.
    c_code_root = Path(sys.prefix, "local", "lib", "zcbor")
elif is_relative_to(P_SCRIPT, (Path(sys.prefix))):
    # Script is under <sys.prefix> (but not <sys.prefix>/local): installed via pip into that
    # prefix. NOTE(review): presumably this covers both root installs without a "local" layout
    # and virtualenvs — confirm.
    c_code_root = Path(sys.prefix, "lib", "zcbor")
elif is_relative_to(P_SCRIPT, (Path(USER_BASE))):
    # Script is under site.USER_BASE: installed via pip as user.
    c_code_root = Path(USER_BASE, "lib", "zcbor")
else:
    # Don't know where the C code is. Assume we are in the repo.
    c_code_root = P_REPO_ROOT
82
83
# Size of "additional" field if num is encoded as int
def sizeof(num):
    """Return the byte length of the CBOR integer argument needed to encode num.

    Values up to 23 fit in the initial byte (0 extra bytes); larger values need 1, 2, 4 or 8
    bytes. Negative inputs are not rejected: they fall in the first bucket and give 0.

    :raises ValueError: If num does not fit in 64 bits.
    """
    if num <= 23:
        return 0
    if num <= UINT8_MAX:
        return 1
    if num <= UINT16_MAX:
        return 2
    if num <= UINT32_MAX:
        return 4
    if num <= UINT64_MAX:
        return 8
    raise ValueError("Number too large (more than 64 bits).")
98
99
# Print only if verbose
def verbose_print(verbose_flag, *things):
    """Forward *things* to print() when verbose_flag is truthy; otherwise do nothing."""
    if not verbose_flag:
        return
    print(*things)
104
105
# Pretty print only if verbose
def verbose_pprint(verbose_flag, *things):
    """Pretty-print each item in *things* (one pprint() call per item) if verbose_flag is truthy.

    Items are printed one at a time because pprint()'s second positional parameter is the output
    stream, so pprint(*things) would misuse the second item as a stream.
    """
    if verbose_flag:
        for thing in things:
            pprint(thing)
110
111
global_counter = 0


# Retrieve a unique id.
def counter(reset=False):
    """Return the next value of the module-wide counter (1, 2, 3, ...).

    With reset=True the counter is set back to 0 and 0 is returned, so the following plain call
    returns 1 again.
    """
    global global_counter
    global_counter = 0 if reset else global_counter + 1
    return global_counter
123
124
# Replace an element in a list or tuple and return the list. For use in
# lambdas.
def list_replace_if_not_null(lst, i, r):
    """Return lst with element i replaced by r, unless that element is the string "NULL".

    If lst[i] == "NULL", lst itself is returned untouched. A tuple yields a new tuple. A list is
    modified in place and a shallow copy of it is returned.
    """
    if lst[i] == "NULL":
        return lst
    if isinstance(lst, tuple):
        items = list(lst)
        items[i] = r
        return tuple(items)
    assert isinstance(lst, list)
    lst[i] = r  # Note: the caller's list is updated too.
    return list(lst)
138
139
# Return a code snippet that assigns the value to a variable var_name and
# returns pointer to the variable, or returns NULL if the value is None.
def val_or_null(value, var_name):
    """Return the C expression "(var_name = value, &var_name)", or "NULL" when value is None."""
    if value is None:
        return "NULL"
    return "(%s = %d, &%s)" % (var_name, value, var_name)
144
145
# Return a code snippet that points tmp_str at the string literal "value" and evaluates to
# &tmp_str, or "NULL" if value is None.
def tmp_str_or_null(value):
    """Return a C comma expression that fills tmp_str with a string literal and yields &tmp_str.

    The length is computed with sizeof() on the literal, minus 1 for the terminating NUL.
    When value is None, "NULL" is returned (like val_or_null()). Before this change, None
    produced a NULL pointer paired with the length of the literal "None".
    """
    if value is None:
        return "NULL"
    literal = f'"{value}"'
    return (f"(tmp_str.value = (uint8_t *){literal}, "
            f"tmp_str.len = sizeof({literal}) - 1, &tmp_str)")
151
152
# Return a pointer to a bool compound literal holding value, or "NULL" if value is None.
def min_bool_or_null(value):
    """Return the C expression "(&(bool){0|1})" for value, or "NULL" when value is None."""
    if value is None:
        return "NULL"
    return f"(&(bool){{{int(value)}}})"
156
157
def deref_if_not_null(access):
    """Return the C address-of expression "&<access>", passing "NULL" through unchanged.

    Despite the name, this takes the address of access rather than dereferencing it.
    """
    if access == "NULL":
        return access
    return "&" + access
160
161
# Return an argument list for a function call to an encoder/decoder function.
def xcode_args(res, *sargs):
    """Build the argument list "state, <res>[, <sargs[0]>, <sargs[1]>]".

    Without extra arguments, res is wrapped in parentheses. With extra arguments, its address
    "&(res)" is passed and only the first two extra arguments are used. "NULL" is always
    passed through as-is.
    """
    if not sargs:
        return "state, %s" % (res if res == "NULL" else "(%s)" % res)
    res_ptr = res if res == "NULL" else "&(%s)" % res
    return "state, %s, %s, %s" % (res_ptr, sargs[0], sargs[1])
170
171
# Return the code that calls an encoder/decoder function with a given set of
# arguments.
def xcode_statement(func, *sargs, **kwargs):
    """Return "(func(<xcode_args(...)>))", or the always-true C expression "1" if func is None."""
    if func is not None:
        return "(%s(%s))" % (func, xcode_args(*sargs, **kwargs))
    return "1"
178
179
def add_semicolon(decl):
    """Make sure the last string in the list decl ends with ";" (in place); return decl."""
    if not decl:
        return decl
    last = decl[-1]
    if last[-1] != ";":
        decl[-1] = last + ";"
    return decl
184
185
def struct_ptr_name(mode):
    """Return the name of the struct pointer parameter: "result" for "decode", else "input"."""
    if mode == "decode":
        return "result"
    return "input"
188
189
def ternary_if_chain(access, names, xcode_strings):
    """Build nested C ternaries selecting xcode_strings[i] where access == names[i].

    Each level reads "((access == name) ? code<newl_ind>: <rest>)". The innermost fallback is
    "false". names and xcode_strings are indexed in parallel and must not be empty.
    """
    head_name, head_code = names[0], xcode_strings[0]
    if len(names) > 1:
        tail = ternary_if_chain(access, names[1:], xcode_strings[1:])
    else:
        tail = "false"
    return "((%s == %s) ? %s%s: %s)" % (access, head_name, head_code, newl_ind, tail)
197
198
199# Class for parsing CDDL. One instance represents one CBOR data item, with a few caveats:
200#  - For repeated data, one instance represents all repetitions.
201#  - For "OTHER" types, one instance points to another type definition.
202#  - For "GROUP" and "UNION" types, there is no separate data item for the instance.
203class CddlParser:
204    def __init__(self, default_max_qty, my_types, my_control_groups, base_name=None,
205                 short_names=False, base_stem=''):
206        self.id_prefix = "temp_" + str(counter())
207        self.id_num = None  # Unique ID number. Only populated if needed.
208        # The value of the data item. Has different meaning for different
209        # types.
210        self.value = None
211        self.max_value = None  # Maximum value. Only used for numbers and bools.
212        self.min_value = None  # Minimum value. Only used for numbers and bools.
213        # The readable label associated with the element in the CDDL.
214        self.label = None
215        self.min_qty = 1  # The minimum number of times this element is repeated.
216        self.max_qty = 1  # The maximum number of times this element is repeated.
217        # The size of the element. Only used for integers, byte strings, and
218        # text strings.
219        self.size = None
220        self.min_size = None  # Minimum size.
221        self.max_size = None  # Maximum size.
222        # Key element. Only for children of "MAP" elements. self.key is of the
223        # same class as self.
224        self.key = None
225        # The element specified via.cbor or.cborseq(only for byte
226        # strings).self.cbor is of the same class as self.
227        self.cbor = None
228        # Any tags (type 6) to precede the element.
229        self.tags = []
230        # The CDDL string used to determine the min_qty and max_qty. Not used after
231        # min_qty and max_qty are determined.
232        self.quantifier = None
233        # Sockets are types starting with "$" or "$$". Do not fail if these aren't defined.
234        self.is_socket = False
235        # If the type has a ".bits <group_name>", this will contain <group_name> which can be looked
236        # up in my_control_groups.
237        self.bits = None
238        # The "type" of the element. This follows the CBOR types loosely, but are more related to
239        # CDDL concepts. The possible types are "INT", "UINT", "NINT", "FLOAT", "BSTR", "TSTR",
240        # "BOOL", "NIL", "UNDEF", "LIST", "MAP","GROUP", "UNION" and "OTHER". "OTHER" represents a
241        # CDDL type defined with '='.
242        self.type = None
243        self.match_str = ""
244        self.errors = list()
245
246        self.my_types = my_types
247        self.my_control_groups = my_control_groups
248        self.default_max_qty = default_max_qty  # args.default_max_qty
249        self.base_name = base_name  # Used as default for self.get_base_name()
250        # Stem which can be used when generating an id.
251        self.base_stem = base_stem.replace("-", "_")
252        self.short_names = short_names
253
254    @classmethod
255    def from_cddl(cddl_class, cddl_string, default_max_qty, *args, **kwargs):
256        my_types = dict()
257
258        type_strings = cddl_class.get_types(cddl_string)
259        # Separate type_strings as keys in two dicts, one dict for strings that start with &( which
260        # are special control operators for .bits, and one dict for all the regular types.
261        my_types = \
262            {my_type: None for my_type, val in type_strings.items() if not val.startswith("&(")}
263        my_control_groups = \
264            {my_cg: None for my_cg, val in type_strings.items() if val.startswith("&(")}
265
266        # Parse the definitions, replacing the each string with a
267        # CodeGenerator instance.
268        for my_type, cddl_string in type_strings.items():
269            parsed = cddl_class(*args, default_max_qty, my_types, my_control_groups, **kwargs,
270                                base_stem=my_type)
271            parsed.get_value(cddl_string.replace("\n", " ").lstrip("&"))
272            parsed = parsed.flatten()[0]
273            if my_type in my_types:
274                my_types[my_type] = parsed
275            elif my_type in my_control_groups:
276                my_control_groups[my_type] = parsed
277
278        counter(True)
279
280        # post_validate all the definitions.
281        for my_type in my_types:
282            my_types[my_type].set_id_prefix()
283            my_types[my_type].post_validate()
284            my_types[my_type].set_base_names()
285        for my_control_group in my_control_groups:
286            my_control_groups[my_control_group].set_id_prefix()
287            my_control_groups[my_control_group].post_validate_control_group()
288
289        return CddlTypes(my_types, my_control_groups)
290
291    # Strip CDDL comments (';') from the string.
292    @staticmethod
293    def strip_comments(instr):
294        comment_regex = r"\;.*?\n"
295        return sub(comment_regex, '', instr)
296
297    @staticmethod
298    def resolve_backslashes(instr):
299        return sub(r"\\\n", " ", instr)
300
301    # Returns a dict containing multiple typename=>string
302    @classmethod
303    def get_types(cls, cddl_string):
304        instr = cls.strip_comments(cddl_string)
305        instr = cls.resolve_backslashes(instr)
306        type_regex = \
307            r"(\s*?\$?\$?([\w-]+)\s*(\/{0,2})=\s*(.*?)(?=(\Z|\s*\$?\$?[\w-]+\s*\/{0,2}=(?!\>))))"
308        result = defaultdict(lambda: "")
309        types = [
310            (key, value, slashes)
311            for (_1, key, slashes, value, _2) in findall(type_regex, instr, S | M)]
312        for key, value, slashes in types:
313            if slashes:
314                result[key] += slashes
315                result[key] += value
316                result[key] = result[key].lstrip(slashes)  # strip from front
317            else:
318                if key in result:
319                    raise ValueError(f"Duplicate CDDL type found: {key}")
320                result[key] = value
321        return dict(result)
322
    # Escaped quotation mark (backslash + '"'); removed from TSTR values by generate_base_name().
    backslash_quotation_mark = r'\"'

    # Generate a (hopefully) unique and descriptive name
    def generate_base_name(self):
        """Build a descriptive (not guaranteed unique) name for this element.

        The first truthy candidate is used, in this order:
          1. the label;
          2. the key's value, if the key is a TSTR or OTHER;
          3. "<text>_tstr" for a TSTR with a value (escaped quotes removed);
          4. "<type><value>" for an INT/UINT with a value;
          5. the name under which this element is stored in my_types, then my_control_groups;
          6. "_<referenced name>" for OTHER;
          7. "_" + the first child's base name for a non-empty LIST/GROUP;
          8. "<cbor value>_bstr" if the .cbor element is a TSTR or OTHER;
          9. the key's generated name followed by the lowercased type;
         10. "<type><bits>" when min_size == max_size, or "<type><minbits>_<maxbits>";
         11. the lowercased type.
        All "-" characters in the result are replaced by "_".
        """
        raw_name = ((
            self.label
            or (self.key.value if self.key and self.key.type in ["TSTR", "OTHER"] else None)
            or (f"{self.value.replace(self.backslash_quotation_mark, '')}_{self.type.lower()}"
                if self.type == "TSTR" and self.value is not None else None)
            or (f"{self.type.lower()}{self.value}"
                if self.type in ["INT", "UINT"] and self.value is not None else None)
            or (next((key for key, value in self.my_types.items() if value == self), None))
            or (next((key for key, value in self.my_control_groups.items() if value == self), None))
            or ("_" + self.value if self.type == "OTHER" else None)
            or ("_" + self.value[0].get_base_name()
                if self.type in ["LIST", "GROUP"] and self.value else None)
            or ((self.cbor.value + "_bstr")
                if self.cbor and self.cbor.type in ["TSTR", "OTHER"] else None)
            or ((self.key.generate_base_name() + self.type.lower()) if self.key else None)
            or (f"{self.type.lower()}{self.min_size * 8}"
                if self.min_size and self.min_size == self.max_size else None)
            or (f"{self.type.lower()}{self.min_size * 8}-{self.max_size * 8}"
                if self.min_size and self.max_size else None)
            or self.type.lower()).replace("-", "_"))
        return raw_name
348
349    # Base name used for functions, variables, and typedefs.
350    def get_base_name(self):
351        generated = self.generate_base_name()
352        return (self.base_name or generated).replace("-", "_")
353
354    # Set an explicit base name for this element.
355    def set_base_name(self, base_name):
356        self.base_name = base_name
357
358    def set_base_names(self):
359        if self.cbor:
360            self.cbor.set_base_name(self.var_name().strip("_") + "_cbor")
361        if self.key:
362            self.key.set_base_name(self.var_name().strip("_") + "_key")
363
364        if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
365            for child in self.value:
366                child.set_base_names()
367        if self.cbor:
368            self.cbor.set_base_names()
369        if self.key:
370            self.key.set_base_names()
371
372    # Add uniqueness to the base name.
373    def id(self, with_prefix=True):
374        raw_name = self.get_base_name()
375        if not with_prefix and self.short_names:
376            return raw_name
377        if (self.id_prefix
378                and (f"{self.id_prefix}_" not in raw_name)
379                and (self.id_prefix != raw_name.strip("_"))):
380            return f"{self.id_prefix}_{raw_name}"
381        if (self.base_stem
382                and (f"{self.base_stem}_" not in raw_name)
383                and (self.base_stem != raw_name.strip("_"))):
384            return f"{self.base_stem}_{raw_name}"
385        return raw_name
386
387    def init_args(self):
388        return (self.default_max_qty,)
389
390    def init_kwargs(self):
391        return {
392            "my_types": self.my_types, "my_control_groups": self.my_control_groups,
393            "short_names": self.short_names}
394
395    def set_id_prefix(self, id_prefix=''):
396        self.id_prefix = id_prefix
397        if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
398            for child in self.value:
399                if child.single_func_impl_condition():
400                    child.set_id_prefix(self.generate_base_name())
401                else:
402                    child.set_id_prefix(self.child_base_id())
403        if self.cbor:
404            self.cbor.set_id_prefix(self.child_base_id())
405        if self.key:
406            self.key.set_id_prefix(self.child_base_id())
407
408    # Id to pass to children for them to use as basis for their id/base name.
409    def child_base_id(self):
410        return self.id()
411
412    def get_id_num(self):
413        if self.id_num is None:
414            self.id_num = counter()
415        return self.id_num
416
417    # Human readable representation.
418    def mrepr(self, newline):
419        reprstr = ''
420        if self.quantifier:
421            reprstr += self.quantifier
422        if self.label:
423            reprstr += self.label + ':'
424        for tag in self.tags:
425            reprstr += f"#6.{tag}"
426        if self.key:
427            reprstr += repr(self.key) + " => "
428        if self.is_unambiguous():
429            reprstr += '/'
430        if self.is_unambiguous_repeated():
431            reprstr += '/'
432        reprstr += self.type
433        if self.size:
434            reprstr += '(%d)' % self.size
435        if newline:
436            reprstr += '\n'
437        if self.value:
438            reprstr += pformat(self.value, indent=4, width=1)
439        if self.cbor:
440            reprstr += " cbor: " + repr(self.cbor)
441        return reprstr.replace('\n', '\n    ')
442
443    def _flatten(self):
444        new_value = []
445        if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
446            for child in self.value:
447                new_value.extend(
448                    child.flatten(allow_multi=self.type != "UNION"))
449            self.value = new_value
450        if self.key:
451            self.key = self.key.flatten()[0]
452        if self.cbor:
453            self.cbor = self.cbor.flatten()[0]
454
455    def flatten(self, allow_multi=False):
456        self._flatten()
457        if self.type == "OTHER" and self.is_socket and self.value not in self.my_types:
458            return []
459        if self.type in ["GROUP", "UNION"]\
460                and (len(self.value) == 1)\
461                and (not (self.key and self.value[0].key)):
462            self.value[0].min_qty *= self.min_qty
463            self.value[0].max_qty *= self.max_qty
464            if not self.value[0].label:
465                self.value[0].label = self.label
466            if not self.value[0].key:
467                self.value[0].key = self.key
468            self.value[0].tags.extend(self.tags)
469            return self.value
470        elif allow_multi and self.type in ["GROUP"] and self.min_qty == 1 and self.max_qty == 1:
471            return self.value
472        else:
473            return [self]
474
475    def set_min_value(self, min_value):
476        self.min_value = min_value
477
478    def set_max_value(self, max_value):
479        self.max_value = max_value
480
481    # Set the self.type and self.value of this element. For use during CDDL
482    # parsing.
483    def type_and_value(self, new_type, value_generator):
484        if self.type is not None:
485            raise TypeError(
486                "Cannot have two values: %s, %s" %
487                (self.type, new_type))
488        if new_type is None:
489            raise TypeError("Cannot set None as type")
490        if new_type == "UNION" and self.value is not None:
491            raise ValueError("Did not expect multiple parsed values for union")
492
493        self.type = new_type
494        self.set_value(value_generator)
495
496    def set_value(self, value_generator):
497        value = value_generator()
498        self.value = value
499
500        if self.type == "OTHER" and self.value.startswith("$"):
501            self.value = self.value.lstrip("$")
502            self.is_socket = True
503
504        if self.type in ["BSTR", "TSTR"]:
505            if value is not None:
506                self.set_size(len(value))
507        if self.type in ["UINT", "NINT"]:
508            if value is not None:
509                self.size = sizeof(value)
510                self.set_min_value(value)
511                self.set_max_value(value)
512                self.convert_min_max()
513        if self.type == "NINT":
514            self.max_value = -1
515
    # Set self.type, self.min_value and self.max_value of this element, and derive
    # self.min_size/self.max_size from those bounds. For use during CDDL parsing.
518    def type_and_range(self, new_type, min_val, max_val, inc_end=True):
519        if not inc_end:
520            max_val -= 1
521        if new_type not in ["INT", "UINT", "NINT"]:
522            raise TypeError(
523                "Only integers (not %s) can have range" %
524                (new_type,))
525        if min_val > max_val:
526            raise TypeError(
527                "Range has larger minimum than maximum (min %d, max %d)" %
528                (min_val, max_val))
529        if min_val == max_val:
530            return self.type_and_value(new_type, min_val)
531        self.type = new_type
532        self.set_min_value(min_val)
533        self.set_max_value(max_val)
534        if new_type in "UINT":
535            self.set_size_range(sizeof(min_val), sizeof(max_val))
536        if new_type == "NINT":
537            self.set_size_range(sizeof(abs(max_val)), sizeof(abs(min_val)))
538        if new_type == "INT":
539            self.set_size_range(None, max(sizeof(abs(max_val)), sizeof(abs(min_val))))
540
541    # Set the self.value and self.size of this element. For use during CDDL
542    # parsing.
543    def type_value_size(self, new_type, value, size):
544        self.type_and_value(new_type, value)
545        self.set_size(size)
546
547    # Set the self.value and self.min_size and self.max_size of this element.
548    # For use during CDDL parsing.
549    def type_value_size_range(self, new_type, value, min_size, max_size):
550        self.type_and_value(new_type, value)
551        self.set_size_range(min_size, max_size)
552
553    # Set the self.label of this element. For use during CDDL parsing.
554    def set_label(self, label):
555        if self.type is not None:
556            raise TypeError("Cannot have label after value: " + label)
557        self.label = label
558
559    # Set the self.quantifier, self.min_qty, and self.max_qty of this element. For
560    # use during CDDL parsing.
561    def set_quantifier(self, quantifier):
562        if self.type is not None:
563            raise TypeError(
564                "Cannot have quantifier after value: " + quantifier)
565
566        quantifier_mapping = [
567            (r"\?", lambda mo: (0, 1)),
568            (r"\*", lambda mo: (0, None)),
569            (r"\+", lambda mo: (1, None)),
570            (r"(.*?)\*\*?(.*)",
571                lambda mo: (int(mo.groups()[0] or "0", 0), int(mo.groups()[1] or "0", 0) or None)),
572        ]
573
574        self.quantifier = quantifier
575        for (reg, handler) in quantifier_mapping:
576            match_obj = match(reg, quantifier)
577            if match_obj:
578                (self.min_qty, self.max_qty) = handler(match_obj)
579                if self.max_qty is None:
580                    self.max_qty = self.default_max_qty
581                return
582        raise ValueError("invalid quantifier: %s" % quantifier)
583
584    def convert_min_max(self):
585        pass
586
    # Apply a ".size" control to this element. For integer types this sets self.max_value (and
    # self.min_value for INT/NINT); for BSTR, TSTR and FLOAT it sets self.min_size and
    # self.max_size. For use during CDDL parsing.
590    def set_size(self, size):
591        if self.type is None:
592            raise TypeError("Cannot have size before value: " + str(size))
593        elif self.type in ["INT", "UINT", "NINT"]:
594            value = 256**size
595            if self.type == "INT":
596                self.max_value = int((value >> 1) - 1)
597            if self.type == "UINT":
598                self.max_value = int(value - 1)
599            if self.type in ["INT", "NINT"]:
600                self.min_value = int(-1 * (value >> 1) + 1)
601            self.convert_min_max()
602        elif self.type in ["BSTR", "TSTR", "FLOAT"]:
603            self.set_size_range(size, size)
604        else:
605            raise TypeError(".size cannot be applied to %s" % self.type)
606
    # Set self.min_size and self.max_size of this element via set_min_size()/set_max_size()
    # (for UINT, set_max_size() may also derive self.max_value). For use during CDDL
    # parsing.
610    def set_size_range(self, min_size, max_size_in, inc_end=True):
611        max_size = max_size_in if inc_end else max_size_in - 1
612
613        if (min_size and min_size < 0 or max_size and max_size < 0) \
614           or (None not in [min_size, max_size] and min_size > max_size):
615            raise TypeError(
616                "Invalid size range (min %d, max %d)" %
617                (min_size, max_size))
618
619        self.set_min_size(min_size)
620        self.set_max_size(max_size)
621
622    # Set self.min_size, and self.minValue if type is UINT.
623    def set_min_size(self, min_size):
624        if self.type == "UINT":
625            self.minValue = 256**min(0, abs(min_size - 1)) if min_size else None
626        self.min_size = min_size if min_size else None
627        self.convert_min_max()
628
629    # Set self.max_size, and self.max_value if type is UINT.
630    def set_max_size(self, max_size):
631        if self.type == "UINT" and max_size and self.max_value is None:
632            if max_size > 8:
633                raise TypeError(
634                    "Size too large for integer. size %d" %
635                    max_size)
636            self.max_value = 256**max_size - 1
637        self.max_size = max_size
638        self.convert_min_max()
639
640    # Set the self.cbor of this element. For use during CDDL parsing.
641    def set_cbor(self, cbor, cborseq):
642        if self.type != "BSTR":
643            raise TypeError(
644                "%s must be used with bstr." %
645                (".cborseq" if cborseq else ".cbor",))
646        self.cbor = cbor
647        if cborseq:
648            self.cbor.max_qty = self.default_max_qty
649
650    def set_bits(self, bits):
651        if self.type != "UINT":
652            raise TypeError(".bits must be used with bstr.")
653        self.bits = bits
654
655    # Set the self.key of this element. For use during CDDL parsing.
656    def set_key(self, key):
657        if self.key is not None:
658            raise TypeError("Cannot have two keys: " + key)
659        if key.type == "GROUP":
660            raise TypeError("A key cannot be a group because it might represent more than 1 type.")
661        self.key = key
662
663    # Set the self.label OR self.key of this element. In the CDDL "foo: bar", foo can be either a
664    # label or a key depending on whether it is in a map. This code uses a slightly different method
665    # for choosing between label and key. If the string is recognized as a type, it is treated as a
666    # key. For use during CDDL parsing.
667    def set_key_or_label(self, key_or_label):
668        if key_or_label.type == "OTHER" and key_or_label.value not in self.my_types:
669            self.set_label(key_or_label.value)
670        else:
671            if key_or_label.type == "OTHER" and self.label is None:
672                self.set_label(key_or_label.value)
673            self.set_key(key_or_label)
674
675    def add_tag(self, tag):
676        self.tags.append(int(tag))
677
678    # Append to the self.value of this element. Used with the "UNION" type, which has
679    # a python list as self.value. The list represents the "children" of the
680    # type. For use during CDDL parsing.
681    def union_add_value(self, value, doubleslash=False):
682        if self.type != "UNION":
683            convert_val = copy(self)
684            self.__init__(*self.init_args(), **self.init_kwargs())
685            self.type_and_value("UNION", lambda: [convert_val])
686
687            self.base_name = convert_val.base_name
688            convert_val.base_name = None
689            self.base_stem = convert_val.base_stem
690
691            if not doubleslash:
692                self.label = convert_val.label
693                self.key = convert_val.key
694                self.quantifier = convert_val.quantifier
695                self.max_qty = convert_val.max_qty
696                self.min_qty = convert_val.min_qty
697
698                convert_val.label = None
699                convert_val.key = None
700                convert_val.quantifier = None
701                convert_val.max_qty = 1
702                convert_val.min_qty = 1
703        self.value.append(value)
704
705    def convert_to_key(self):
706        convert_val = copy(self)
707        self.__init__(*self.init_args(), **self.init_kwargs())
708        self.set_key(convert_val)
709
710        self.label = convert_val.label
711        self.quantifier = convert_val.quantifier
712        self.max_qty = convert_val.max_qty
713        self.min_qty = convert_val.min_qty
714        self.base_name = convert_val.base_name
715        self.base_stem = convert_val.base_stem
716
717        convert_val.label = None
718        convert_val.quantifier = None
719        convert_val.max_qty = 1
720        convert_val.min_qty = 1
721        convert_val.base_name = None
722
    # Parse from the beginning of instr (string) until a full element has been parsed. self will
    # become that element. This function is recursive, so if a nested element
    # ("MAP"/"LIST"/"UNION"/"GROUP") is encountered, this function will create new instances and add
    # them to self.value as a list. Likewise, if a key or cbor definition is encountered, a new
    # element will be created and assigned to self.key or self.cbor. When new elements are created,
    # get_value() is called on those elements, via parse().
729    def get_value(self, instr):
730        match_uint = r"(0x[0-9a-fA-F]+|0o[0-7]+|0b[01]+|\d+)"
731        match_int = r"(-?" + match_uint + ")"
732        match_nint = r"(-" + match_uint + ")"
733
734        # The following regexes match different parts of the element. The order of the list is
735        # important because it implements the operator precendence defined in the CDDL spec. Note
736        # that some regexes are inserted afterwards because they involve a match of a concatenation
737        # of all the initial regexes (with a '|' between each element).
738        types = [
739            (r'\[(?P<item>(?>[^[\]]+|(?R))*)\]',
740             lambda list_str: self.type_and_value("LIST", lambda: self.parse(list_str))),
741            (r'\((?P<item>(?>[^\(\)]+|(?R))*)\)',
742             lambda group_str: self.type_and_value("GROUP", lambda:self.parse(group_str))),
743            (r'{(?P<item>(?>[^{}]+|(?R))*)}',
744             lambda map_str: self.type_and_value("MAP", lambda: self.parse(map_str))),
745            (r'\/\/\s*(?P<item>.+?)(?=\/\/|\Z)',
746             lambda union_str: self.union_add_value(
747                 self.parse("(%s)" % union_str if ',' in union_str else union_str)[0],
748                 doubleslash=True)),
749            (r'(\=\>)',
750             lambda _: self.convert_to_key()),
751            (r'(?P<item>\w)\s*\:',
752             lambda key_str: self.set_key_or_label(keystr)),
753            (r'([+*?])',
754             self.set_quantifier),
755            (r'(' + match_uint + r'\*\*?' + match_uint + r'?)',
756             self.set_quantifier),
757            (r'uint(?![\w-])',
758             lambda _: self.type_and_value("UINT", lambda: None)),
759            (r'nint(?![\w-])',
760             lambda _: self.type_and_value("NINT", lambda: None)),
761            (r'int(?![\w-])',
762             lambda _: self.type_and_value("INT", lambda: None)),
763            (r'float(?![\w-])',
764             lambda _: self.type_and_value("FLOAT", lambda: None)),
765            (r'float16(?![\w-])',
766             lambda _: self.type_value_size("FLOAT", lambda: None, 2)),
767            (r'float16-32(?![\w-])',
768             lambda _: self.type_value_size_range("FLOAT", lambda: None, 2, 4)),
769            (r'float32(?![\w-])',
770             lambda _: self.type_value_size("FLOAT", lambda: None, 4)),
771            (r'float32-64(?![\w-])',
772             lambda _: self.type_value_size_range("FLOAT", lambda: None, 4, 8)),
773            (r'float64(?![\w-])',
774             lambda _: self.type_value_size("FLOAT", lambda: None, 8)),
775            (r'\-?\d*\.\d+',
776             lambda num: self.type_and_value("FLOAT", lambda: float(num))),
777            (match_uint + r'\.\.' + match_uint,
778             lambda _range: self.type_and_range(
779                 "UINT", *map(lambda num: int(num, 0), _range.split("..")))),
780            (match_nint + r'\.\.' + match_uint,
781             lambda _range: self.type_and_range(
782                 "INT", *map(lambda num: int(num, 0), _range.split("..")))),
783            (match_nint + r'\.\.' + match_nint,
784             lambda _range: self.type_and_range(
785                 "NINT", *map(lambda num: int(num, 0), _range.split("..")))),
786            (match_uint + r'\.\.\.' + match_uint,
787             lambda _range: self.type_and_range(
788                 "UINT", *map(lambda num: int(num, 0), _range.split("...")), inc_end=False)),
789            (match_nint + r'\.\.\.' + match_uint,
790             lambda _range: self.type_and_range(
791                 "INT", *map(lambda num: int(num, 0), _range.split("...")), inc_end=False)),
792            (match_nint + r'\.\.\.' + match_nint,
793             lambda _range: self.type_and_range(
794                 "NINT", *map(lambda num: int(num, 0), _range.split("...")), inc_end=False)),
795            (match_nint,
796             lambda num: self.type_and_value("NINT", lambda: int(num, 0))),
797            (match_uint,
798             lambda num: self.type_and_value("UINT", lambda: int(num, 0))),
799            (r'bstr(?!\w)',
800             lambda _: self.type_and_value("BSTR", lambda: None)),
801            (r'\'(?P<item>.*?)(?<!\\)\'',
802             lambda string: self.type_and_value("BSTR", lambda: string)),
803            (r'tstr(?!\w)',
804             lambda _: self.type_and_value("TSTR", lambda: None)),
805            (r'\"(?P<item>.*?)(?<!\\)\"',
806             lambda string: self.type_and_value("TSTR", lambda: string)),
807            (r'bool(?!\w)',
808             lambda _: self.type_and_value("BOOL", lambda: None)),
809            (r'true(?!\w)',
810             lambda _: self.type_and_value("BOOL", lambda: True)),
811            (r'false(?!\w)',
812             lambda _: self.type_and_value("BOOL", lambda: False)),
813            (r'nil(?!\w)',
814             lambda _: self.type_and_value("NIL", lambda: None)),
815            (r'undefined(?!\w)',
816             lambda _: self.type_and_value("UNDEF", lambda: None)),
817            (r'any(?!\w)',
818             lambda _: self.type_and_value("ANY", lambda: None)),
819            (r'#6\.(?P<item>\d+)',
820             self.add_tag),
821            (r'(\$?\$?[\w-]+)',
822             lambda other_str: self.type_and_value("OTHER", lambda: other_str)),
823            (r'\.size \(?(?P<item>' + match_int + r'\.\.' + match_int + r')\)?',
824             lambda _range: self.set_size_range(*map(lambda num: int(num, 0), _range.split("..")))),
825            (r'\.size \(?(?P<item>' + match_int + r'\.\.\.' + match_int + r')\)?',
826             lambda _range: self.set_size_range(
827                 *map(lambda num: int(num, 0), _range.split("...")), inc_end=False)),
828            (r'\.size \(?(?P<item>' + match_uint + r')\)?',
829             lambda size: self.set_size(int(size, 0))),
830            (r'\.gt \(?(?P<item>' + match_int + r')\)?',
831             lambda minvalue: self.set_min_value(int(minvalue, 0) + 1)),
832            (r'\.lt \(?(?P<item>' + match_int + r')\)?',
833             lambda maxvalue: self.set_max_value(int(maxvalue, 0) - 1)),
834            (r'\.ge \(?(?P<item>' + match_int + r')\)?',
835             lambda minvalue: self.set_min_value(int(minvalue, 0))),
836            (r'\.le \(?(?P<item>' + match_int + r')\)?',
837             lambda maxvalue: self.set_max_value(int(maxvalue, 0))),
838            (r'\.eq \(?(?P<item>' + match_int + r')\)?',
839             lambda value: self.set_value(lambda: int(value, 0))),
840            (r'\.eq \"(?P<item>.*?)(?<!\\)\"',
841             lambda value: self.set_value(lambda: value)),
842            (r'\.cbor (\((?P<item>(?>[^\(\)]+|(?1))*)\))',
843             lambda type_str: self.set_cbor(self.parse(type_str)[0], False)),
844            (r'\.cbor (?P<item>[^\s,]+)',
845             lambda type_str: self.set_cbor(self.parse(type_str)[0], False)),
846            (r'\.cborseq (\((?P<item>(?>[^\(\)]+|(?1))*)\))',
847             lambda type_str: self.set_cbor(self.parse(type_str)[0], True)),
848            (r'\.cborseq (?P<item>[^\s,]+)',
849             lambda type_str: self.set_cbor(self.parse(type_str)[0], True)),
850            (r'\.bits (?P<item>[\w-]+)',
851             lambda bits_str: self.set_bits(bits_str))
852        ]
853        all_type_regex = '|'.join([regex for (regex, _) in (types[:3] + types[5:])])
854        for i in range(0, all_type_regex.count("item")):
855            all_type_regex = all_type_regex.replace("item", "it%dem" % i, 1)
856        types.insert(5, (r'(?P<item>' + all_type_regex + r')\s*\:',
857                         lambda key_str: self.set_key_or_label(self.parse(key_str)[0])))
858        types.insert(6, (r'(?P<item>' + all_type_regex + r')\s*\=\>',
859                         lambda key_str: self.set_key(self.parse(key_str)[0])))
860        types.insert(7, (r'\/\s*(?P<item>((' + all_type_regex + r')\s*)+?)(?=\/|\,|\Z)',
861                         lambda union_str: self.union_add_value(self.parse(union_str)[0])))
862
863        # Keep parsing until a comma, or to the end of the string.
864        while instr != '' and instr[0] != ',':
865            match_obj = None
866            for (reg, handler) in types:
867                match_obj = match(reg, instr)
868                if match_obj:
869                    try:
870                        match_str = match_obj.group("item")
871                    except IndexError:
872                        match_str = match_obj.group(0)
873                    try:
874                        handler(match_str)
875                    except Exception as e:
876                        raise Exception("Failed while parsing this: '%s'" % match_str) from e
877                    self.match_str += match_str
878                    old_len = len(instr)
879                    instr = sub(reg, '', instr, count=1).lstrip()
880                    if old_len == len(instr):
881                        raise Exception("empty match")
882                    break
883
884            if not match_obj:
885                raise TypeError("Could not parse this: '%s'" % instr)
886
887        instr = instr[1:]
888        if not self.type:
889            raise ValueError("No proper value while parsing: %s" % instr)
890
891        # Return the unparsed part of the string.
892        return instr.strip()
893
894    # For checking whether this element has a key (i.e. to check that it is a valid "MAP" child).
895    # This must have some recursion since CDDL allows the key to be hidden
896    # behind layers of indirection.
897    def elem_has_key(self):
898        return self.key is not None\
899            or (self.type == "OTHER" and self.my_types[self.value].elem_has_key())\
900            or (self.type in ["GROUP", "UNION"]
901                and all(child.elem_has_key() for child in self.value))
902
903    # Function for performing validations that must be done after all parsing is complete. This is
904    # recursive, so it will post_validate all its children + key + cbor.
905    def post_validate(self):
906        # Validation of this element.
907        if self.type in ["LIST", "MAP"]:
908            none_keys = [child for child in self.value if not child.elem_has_key()]
909            child_keys = [child for child in self.value if child not in none_keys]
910            if self.type == "MAP" and none_keys:
911                raise TypeError(
912                    "Map member(s) must have key: " + str(none_keys) + " pointing to "
913                    + str(
914                        [self.my_types[elem.value] for elem in none_keys
915                            if elem.type == "OTHER"]))
916            if self.type == "LIST" and child_keys:
917                raise TypeError(
918                    str(self) + linesep
919                    + "List member(s) cannot have key: " + str(child_keys) + " pointing to "
920                    + str(
921                        [self.my_types[elem.value] for elem in child_keys
922                            if elem.type == "OTHER"]))
923        if self.type == "OTHER":
924            if self.value not in self.my_types.keys() or not isinstance(
925                    self.my_types[self.value], type(self)):
926                raise TypeError("%s has not been parsed." % self.value)
927        if self.type == "LIST":
928            for child in self.value[:-1]:
929                if child.type == "ANY":
930                    if child.min_qty != child.max_qty:
931                        raise TypeError(f"ambiguous quantity of 'any' is not supported in list, "
932                                        + "except as last element:\n{str(child)}")
933        if self.type == "UNION" and len(self.value) > 1:
934            if any(((not child.key and child.type == "ANY") or (
935                    child.key and child.key.type == "ANY")) for child in self.value):
936                raise TypeError(
937                    "'any' inside union is not supported since it would always be triggered.")
938
939        # Validation of child elements.
940        if self.type in ["MAP", "LIST", "UNION", "GROUP"]:
941            for child in self.value:
942                child.post_validate()
943        if self.key:
944            self.key.post_validate()
945        if self.cbor:
946            self.cbor.post_validate()
947
948    def post_validate_control_group(self):
949        if self.type != "GROUP":
950            raise TypeError("control groups must be of GROUP type.")
951        for c in self.value:
952            if c.type != "UINT" or c.value is None or c.value < 0:
953                raise TypeError("control group members must be literal positive integers.")
954
955    # Parses entire instr and returns a list of instances.
956    def parse(self, instr):
957        instr = instr.strip()
958        values = []
959        while instr != '':
960            value = type(self)(*self.init_args(), **self.init_kwargs(), base_stem=self.base_stem)
961            instr = value.get_value(instr)
962            values.append(value)
963        return values
964
965    def __repr__(self):
966        return self.mrepr(False)
967
968
class CddlXcoder(CddlParser):
    """CddlParser extended with the helpers used when generating C encoder/decoder code.

    The methods below name the C variables generated for an element, build the C
    "access path" to them, and decide which extra variables (present/count/choice/key/
    cbor), typedefs and encoder/decoder functions an element needs.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # The prefix used for C code accessing this element, i.e. the struct
        # hierarchy leading up to this element.
        self.accessPrefix = None
        # Used as a guard against endless recursion in self.dependsOn()
        # (NOTE(review): dependsOn() is not defined in this class; presumably elsewhere).
        self.dependsOnCall = False
        # Set by set_skipped(); when True, skip_condition() always returns True.
        self.skipped = False

    # Name of variables and enum members for this element.
    # With a prefix (or when short names are disabled) this is "_" + id(). Otherwise it
    # is the bare id(), prefixed with "_" only when it would clash with "int", "bool"
    # or "float".
    def var_name(self, with_prefix=False):
        if with_prefix or not self.short_names:
            name = ("_%s" % self.id(with_prefix=with_prefix))
        else:
            name = self.id(with_prefix=with_prefix)
            if name in ["int", "bool", "float"]:
                name = "_" + name
        return name

    # Whether this element is skipped: explicitly marked skipped, a LIST/MAP/GROUP whose
    # value does not have multiple members, or an OTHER reference without extra
    # variables that either needs its own function or is itself a named type.
    def skip_condition(self):
        if self.skipped:
            return True
        if self.type in ["LIST", "MAP", "GROUP"]:
            return not self.multi_val_condition()
        if self.type == "OTHER":
            return ((not self.repeated_multi_var_condition())
                    and (not self.multi_var_condition())
                    and (self.single_func_impl_condition() or self in self.my_types.values()))
        return False

    # Set self.skipped. Key-less elements that need a range check and a function for
    # their repeated part are always marked as skipped; otherwise `skipped` is used.
    def set_skipped(self, skipped):
        if self.range_check_condition() \
           and self.repeated_single_func_impl_condition() \
           and not self.key:
            self.skipped = True
        else:
            self.skipped = skipped

    # Recursively set the access prefix for this element and all its children (and its
    # key and .cbor type, when they get variables), and update the children's
    # "skipped" state.
    def set_access_prefix(self, prefix):
        self.accessPrefix = prefix
        if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
            # All prefixes are assigned first, then the skipped states, as before.
            for child in self.value:
                child.set_access_prefix(self.var_access())
            for child in self.value:
                child.set_skipped(self.skip_condition())
        elif self in self.my_types.values() and self.type != "OTHER":
            self.set_skipped(not self.multi_member())
        if self.key is not None:
            self.key.set_access_prefix(self.var_access())
        if self.cbor_var_condition():
            self.cbor.set_access_prefix(self.var_access())

    # Whether this type has multiple member variables.
    def multi_member(self):
        return self.multi_var_condition() or self.repeated_multi_var_condition()

    # Whether a single instance of this element has exactly one possible value
    # (literal values, NIL/UNDEF/ANY, or an OTHER reference to an unambiguous type).
    def is_unambiguous_value(self):
        return (self.type in ["NIL", "UNDEF", "ANY"]
                or (self.type in ["INT", "NINT", "UINT", "FLOAT", "BSTR", "TSTR", "BOOL"]
                    and self.value is not None)
                or (self.type == "OTHER" and self.my_types[self.value].is_unambiguous()))

    # Whether the repeated part of this element is unambiguous: an unambiguous value with
    # an unambiguous (or no) key, or a container whose children are all unambiguous.
    def is_unambiguous_repeated(self):
        return ((self.is_unambiguous_value()
                 and (self.key is None or self.key.is_unambiguous_repeated()))
                or (self.type in ["LIST", "GROUP", "MAP"] and len(self.value) == 0)
                or (self.type in ["LIST", "GROUP", "MAP"]
                    and all((child.is_unambiguous() for child in self.value))))

    # Whether this element is unambiguous including its quantity (fixed repetition count).
    def is_unambiguous(self):
        return (self.is_unambiguous_repeated() and (self.min_qty == self.max_qty))

    # Create an access prefix based on an existing prefix, delimiter and a
    # suffix.
    def access_append_delimiter(self, prefix, delimiter, *suffix):
        assert prefix is not None, "No access prefix for %s" % self.var_name()
        return delimiter.join((prefix,) + suffix)

    # Create an access prefix from this element's prefix, delimiter and a
    # provided suffix.
    def access_append(self, *suffix):
        return self.access_append_delimiter(self.accessPrefix, '.', *suffix)

    # "Path" to this element's variable, or "NULL" when the element is unambiguous.
    def var_access(self):
        if self.is_unambiguous():
            return "NULL"
        return self.access_append()

    # "Path" to access this element's actual value variable, or "NULL" when the
    # repeated part is unambiguous. Skipped elements use var_access() instead.
    def val_access(self):
        if self.is_unambiguous_repeated():
            return "NULL"
        if self.skip_condition():
            return self.var_access()
        return self.access_append(self.var_name())

    # Like val_access(), but never falls back to var_access() for skipped elements.
    def repeated_val_access(self):
        if self.is_unambiguous_repeated():
            return "NULL"
        return self.access_append(self.var_name())

    # Whether to include a "present" variable for this element.
    def present_var_condition(self):
        return self.min_qty == 0 and isinstance(self.max_qty, int) and self.max_qty <= 1

    # Whether to include a "count" variable for this element.
    def count_var_condition(self):
        return isinstance(self.max_qty, str) or self.max_qty > 1

    # Whether this element can hold CBOR data: false for NIL, UNDEF and ANY, and for
    # OTHER references that resolve to such types.
    def is_cbor(self):
        return (self.type not in ["NIL", "UNDEF", "ANY"]) \
            and ((self.type != "OTHER") or (self.my_types[self.value].is_cbor()))

    # Whether to include a "cbor" variable for this element.
    def cbor_var_condition(self):
        return (self.cbor is not None) and self.cbor.is_cbor()

    # Whether to include a "choice" variable for this element.
    def choice_var_condition(self):
        return self.type == "UNION"

    # Whether this specific element has a key (not looking through indirection).
    def reduced_key_var_condition(self):
        return self.key is not None

    # Whether to include a "key" variable for this element: it has a key itself, it is
    # an OTHER reference to a type that needs one, or it is a GROUP/UNION whose first
    # member has a key.
    def key_var_condition(self):
        if self.reduced_key_var_condition():
            return True
        if self.type == "OTHER" and self.my_types[self.value].key_var_condition():
            return True
        if (self.type in ["GROUP", "UNION"]
                and len(self.value) >= 1
                and self.value[0].reduced_key_var_condition()):
            return True
        return False

    # Whether this value adds any repeated elements by itself. I.e. excluding
    # multiple elements from children.
    def self_repeated_multi_var_condition(self):
        return (self.key_var_condition()
                or self.cbor_var_condition()
                or self.choice_var_condition())

    # Whether this element's actual value has multiple members.
    def multi_val_condition(self):
        return (
            self.type in ["LIST", "MAP", "GROUP", "UNION"]
            and (len(self.value) > 1 or (len(self.value) == 1 and self.value[0].multi_member())))

    # Whether any extra variables are to be included for this element for each
    # repetition.
    def repeated_multi_var_condition(self):
        return self.self_repeated_multi_var_condition() or self.multi_val_condition()

    # Whether any extra variables are to be included for this element outside
    # of repetitions.
    # Also, whether this element must involve a call to multi_xcode(), i.e. unless
    # it's repeated exactly once.
    def multi_var_condition(self):
        return self.present_var_condition() or self.count_var_condition()

    # Whether this element needs a runtime range check: an integer/string type without a
    # fixed value that has min/max value limits, a .bits control group (UINT), or
    # min/max size limits (BSTR/TSTR). OTHER references defer to the referenced type.
    def range_check_condition(self):
        if self.type == "OTHER":
            return self.my_types[self.value].range_check_condition()
        if self.type not in ["INT", "NINT", "UINT", "BSTR", "TSTR"]:
            return False
        if self.value is not None:
            return False
        if self.type in ["INT", "NINT", "UINT"] \
                and (self.min_value is not None or self.max_value is not None):
            return True
        if self.type == "UINT" and self.bits:
            return True
        if self.type in ["BSTR", "TSTR"] \
                and (self.min_size is not None or self.max_size is not None):
            return True
        return False

    # Whether this element should have a typedef in the code.
    def type_def_condition(self):
        return bool(
            self in self.my_types.values() and self.multi_member() and not self.is_unambiguous())

    # Whether this type needs a typedef for its repeated part.
    def repeated_type_def_condition(self):
        return (
            self.repeated_multi_var_condition()
            and self.multi_var_condition()
            and not self.is_unambiguous_repeated())

    # Whether this element needs its own encoder/decoder function.
    def single_func_impl_condition(self):
        return (
            self.reduced_key_var_condition()
            or self.cbor_var_condition()
            or (self.tags and self in self.my_types.values())
            or self.type_def_condition()
            or (self.type in ["LIST", "MAP", "GROUP"] and len(self.value) != 0))

    # Whether the repeated part of this element needs its own encoder/decoder function.
    def repeated_single_func_impl_condition(self):
        return self.repeated_type_def_condition() \
            or (self.type in ["LIST", "MAP", "GROUP"] and self.multi_member()) \
            or (
                self.multi_var_condition()
                and (self.self_repeated_multi_var_condition() or self.range_check_condition()))

    # The literal integer identifying this element (via its key, its own unambiguous
    # UINT/NINT value, or a single-instance GROUP/OTHER it wraps), or None.
    def int_val(self):
        if self.key:
            return self.key.int_val()
        elif self.type in ("UINT", "NINT") and self.is_unambiguous():
            return self.value
        elif self.type == "GROUP" and not self.count_var_condition():
            return self.value[0].int_val()
        elif self.type == "OTHER" \
                and not self.count_var_condition() \
                and not self.single_func_impl_condition() \
                and not self.my_types[self.value].single_func_impl_condition():
            return self.my_types[self.value].int_val()
        return None

    def is_int_disambiguated(self):
        return self.int_val() is not None

    # Whether every child has a distinct, non-None int_val() within the signed 32-bit
    # range. Returns False when there are no children (max()/min() of an empty set
    # would otherwise raise ValueError).
    def all_children_int_disambiguated(self):
        values = set(child.int_val() for child in self.value)
        if not values or None in values or len(values) != len(self.value):
            return False
        return max(values) <= INT32_MAX and min(values) >= INT32_MIN

    # Name of the "present" variable for this element.
    def present_var_name(self):
        return "%s_present" % (self.var_name())

    # Full "path" of the "present" variable for this element.
    def present_var_access(self):
        return self.access_append(self.present_var_name())

    # Name of the "count" variable for this element.
    def count_var_name(self):
        return "%s_count" % (self.var_name())

    # Full "path" of the "count" variable for this element.
    def count_var_access(self):
        return self.access_append(self.count_var_name())

    # Name of the "choice" variable for this element.
    def choice_var_name(self):
        return self.var_name() + "_choice"

    # Name of the enum entry for this element, optionally with " = <int_val()>".
    def enum_var_name(self, int_val=False):
        if not int_val:
            retval = self.var_name(with_prefix=True)
        else:
            retval = f"{self.var_name(with_prefix=True)} = {self.int_val()}"
        return retval

    # Full "path" of the "choice" variable for this element.
    def choice_var_access(self):
        return self.access_append(self.choice_var_name())
1244
1245
class CddlValidationError(Exception):
    """Raised when data does not conform to the CDDL it is being validated against."""
1248
1249
1250# Subclass of tuple for holding key,value pairs. This is to make it possible to use isinstance() to
1251# separate it from other tuples.
class KeyTuple(tuple):
    """A tuple subclass marking a (key, value) pair.

    Being its own type lets isinstance() tell these pairs apart from ordinary tuples.
    Construct it like a tuple, from a single iterable: ``KeyTuple((key, value))``.
    """

    def __new__(cls, *in_tuple):
        return super().__new__(cls, *in_tuple)
1255
1256
1257# Convert data between CBOR, JSON and YAML, and validate against the provided CDDL.
1258# Decode and validate CBOR into Python structures to be able to make Python scripts that manipulate
1259# CBOR code.
1260class DataTranslator(CddlXcoder):
1261    # Format a Python object for printing by adding newlines and indentation.
1262    @staticmethod
1263    def format_obj(obj):
1264        formatted = pformat(obj)
1265        out_str = ""
1266        indent = 0
1267        new_line = True
1268        for c in formatted:
1269            if new_line:
1270                if c == " ":
1271                    continue
1272                new_line = False
1273            out_str += c
1274            if c in "[(":
1275                indent += 1
1276            if c in ")]" and indent > 0:
1277                indent -= 1
1278            if c in "[(,":
1279                out_str += linesep
1280                out_str += "  " * indent
1281                new_line = True
1282        return out_str
1283
1284    # Override the id() function
1285    def id(self):
1286        return self.generate_base_name().strip("_")
1287
1288    # Override the var_name()
1289    def var_name(self):
1290        return self.id()
1291
1292    # Check a condition and raise a CddlValidationError if not.
1293    def _decode_assert(self, test, msg=""):
1294        if not test:
1295            raise CddlValidationError(f"Data did not decode correctly {'('+msg+')' if msg else ''}")
1296
    # Check that no unexpected tags are attached to this data and that all expected tags are
    # present. Return the object with the processed tags unwrapped.
1298    def _check_tag(self, obj):
1299        tags = copy(self.tags)  # All expected tags
1300        # Process all tags present in obj
1301        while isinstance(obj, CBORTag):
1302            if obj.tag in tags or self.type == "ANY":
1303                if obj.tag in tags:
1304                    tags.remove(obj.tag)
1305                obj = obj.value
1306                continue
1307            elif self.type in ["OTHER", "GROUP", "UNION"]:
1308                break
1309            self._decode_assert(False, f"Tag ({obj.tag}) not expected for {self}")
1310        # Check that all expected tags were found in obj.
1311        self._decode_assert(not tags, f"Expected tags ({tags}), but none present.")
1312        return obj
1313
1314    # Return our expected python type as returned by cbor2.
1315    def _expected_type(self):
1316        return {
1317            "UINT": lambda: int,
1318            "INT": lambda: int,
1319            "NINT": lambda: int,
1320            "FLOAT": lambda: float,
1321            "TSTR": lambda: str,
1322            "BSTR": lambda: bytes,
1323            "NIL": lambda: type(None),
1324            "UNDEF": lambda: type(undefined),
1325            "ANY": lambda: (int, float, str, bytes, type(None), type(undefined), bool, list, dict),
1326            "BOOL": lambda: bool,
1327            "LIST": lambda: (tuple, list),
1328            "MAP": lambda: dict,
1329        }[self.type]()
1330
1331    # Check that the decoded object has the correct type.
1332    def _check_type(self, obj):
1333        if self.type not in ["OTHER", "GROUP", "UNION"]:
1334            exp_type = self._expected_type()
1335            self._decode_assert(
1336                isinstance(obj, exp_type),
1337                f"{str(self)}: Wrong type of {str(obj)}, expected {str(exp_type)}")
1338
    # Check that the decoded value conforms to the restrictions in the CDDL.
1340    def _check_value(self, obj):
1341        if self.type in ["UINT", "INT", "NINT", "FLOAT", "TSTR", "BSTR", "BOOL"] \
1342                and self.value is not None:
1343            self._decode_assert(
1344                self.value == obj,
1345                f"{obj} should have value {self.value} according to {self.var_name()}")
1346        if self.type in ["UINT", "INT", "NINT", "FLOAT"]:
1347            if self.min_value is not None:
1348                self._decode_assert(obj >= self.min_value, "Minimum value: " + str(self.min_value))
1349            if self.max_value is not None:
1350                self._decode_assert(obj <= self.max_value, "Maximum value: " + str(self.max_value))
1351        if self.type == "UINT":
1352            if self.bits:
1353                mask = sum(((1 << b.value) for b in self.my_control_groups[self.bits].value))
1354                self._decode_assert(not (obj & ~mask), "Allowed bitmask: " + bin(mask))
1355        if self.type in ["TSTR", "BSTR"]:
1356            if self.min_size is not None:
1357                self._decode_assert(
1358                    len(obj) >= self.min_size, "Minimum length: " + str(self.min_size))
1359            if self.max_size is not None:
1360                self._decode_assert(
1361                    len(obj) <= self.max_size, "Maximum length: " + str(self.max_size))
1362
1363    # Check that the object is not a KeyTuple since that means it hasn't been properly processed.
1364    def _check_key(self, obj):
1365        self._decode_assert(
1366            not isinstance(obj, KeyTuple), "Unexpected key found: (key,value)=" + str(obj))
1367
1368    # Recursively remove intermediate objects that have single members. Keep lists as is.
1369    def _flatten_obj(self, obj):
1370        if isinstance(obj, tuple) and len(obj) == 1:
1371            return self._flatten_obj(obj[0])
1372        return obj
1373
    # If obj is a list whose single member is a one-element list/tuple with an attribute called
    # `name` (e.g. a namedtuple field), return a list of that inner element; otherwise return obj.
1375    def _flatten_list(self, name, obj):
1376        if (isinstance(obj, list)
1377                and len(obj) == 1
1378                and (isinstance(obj[0], list) or isinstance(obj[0], tuple))
1379                and len(obj[0]) == 1
1380                and hasattr(obj[0], name)):
1381            return [obj[0][0]]
1382        return obj
1383
1384    # Construct a namedtuple object from my_list. my_list contains tuples of name/value.
1385    # Also, attempt to flatten redundant levels of abstraction.
1386    def _construct_obj(self, my_list):
1387        if my_list == []:
1388            return None
1389        names, values = tuple(zip(*my_list))
1390        if len(values) == 1:
1391            values = (self._flatten_obj(values[0]), )
1392        values = tuple(self._flatten_list(names[i], values[i]) for i in range(len(values)))
1393        assert (not any((isinstance(elem, KeyTuple) for elem in values))), \
1394            f"KeyTuple not processed: {values}"
1395        return namedtuple("_", names)(*values)
1396
    # Construct an entry for obj and add it to my_list if relevant. Also, process any KeyTuples
    # present.
    def _add_if(self, my_list, obj, expect_key=False, name=None):
        """Append (name, value) entries for ``obj`` to ``my_list`` unless this element is
        unambiguous.

        - If ``expect_key`` is set and this is a key-less OTHER reference, the work is
          delegated to the referenced type (``expect_key`` and ``name`` are not passed on).
        - KeyTuple members of a list ``obj`` are replaced, in place, by the object that
          _construct_obj() builds from a recursive _add_if() on that member.
        - A KeyTuple ``obj`` is unpacked; a non-None key is added via self.key._add_if()
          under "<var_name()>_key".
        - For a BSTR with a ``.cbor`` type, both the decoded value(s) (under the entry
          name) and the raw bytes (under "<entry name>_bstr") are appended.

        Args:
            my_list: List of (name, value) tuples; appended to in place.
            obj: The decoded object, possibly a KeyTuple or a list.
            expect_key: See above.
            name: Entry name to use instead of self.var_name().
        """
        if expect_key and self.type == "OTHER" and self.key is None:
            self.my_types[self.value]._add_if(my_list, obj)
            return
        # Unambiguous elements carry no information, so nothing is added.
        if self.is_unambiguous():
            return
        if isinstance(obj, list):
            for i in range(len(obj)):
                if isinstance(obj[i], KeyTuple):
                    retvals = list()
                    self._add_if(retvals, obj[i])
                    # Note: this mutates the caller's list.
                    obj[i] = self._construct_obj(retvals)
                # NOTE(review): this assert is re-evaluated on every iteration that sees a bytes member.
                if self.type == "BSTR" and self.cbor_var_condition() and isinstance(obj[i], bytes):
                    assert all((isinstance(o, bytes) for o in obj)), \
                           """Unsupported configuration for cbor bstr. If a list contains a
CBOR-formatted bstr, all elements must be bstrs. If not, it is a programmer error."""
        if isinstance(obj, KeyTuple):
            key, obj = obj
            if key is not None:
                # The key entry name is always derived from var_name(), not from `name`.
                self.key._add_if(my_list, key, name=self.var_name() + "_key")
        if self.type == "BSTR" and self.cbor_var_condition():
            # If a bstr is CBOR-formatted, add both the string and the decoding of the string here
            if isinstance(obj, list) and all((isinstance(o, bytes) for o in obj)):
                # One or more bstr in a list (i.e. it is optional or repeated)
                my_list.append((name or self.var_name(), [self.cbor.decode_str(o) for o in obj]))
                my_list.append(((name or self.var_name()) + "_bstr", obj))
                return
            if isinstance(obj, bytes):
                my_list.append((name or self.var_name(), self.cbor.decode_str(obj)))
                my_list.append(((name or self.var_name()) + "_bstr", obj))
                return
        my_list.append((name or self.var_name(), obj))
1430
1431    # Throw CddlValidationError if iterator is not empty. This consumes one element if present.
1432    def _iter_is_empty(self, it):
1433        try:
1434            val = next(it)
1435        except StopIteration:
1436            return True
1437        raise CddlValidationError(
1438            f"Iterator not consumed while parsing \n{self}\nRemaining elements:\n elem: "
1439            + "\n elem: ".join(str(elem) for elem in ([val] + list(it))))
1440
1441    # Get next element from iterator, throw CddlValidationError instead of StopIteration.
1442    def _iter_next(self, it):
1443        try:
1444            next_obj = next(it)
1445            return next_obj
1446        except StopIteration:
1447            raise CddlValidationError("Iterator empty")
1448
1449    # Decode single CDDL value, excluding repetitions
    def _decode_single_obj(self, obj):
        """Validate a single CBOR object against this element (no repetitions) and decode it.

        Primitive types are returned as-is. _check_tag() may replace obj first (presumably
        unwrapping a CBORTag).

        LIST, MAP and UNION values are collected with _add_if() and turned into a namedtuple
        by _construct_obj(). OTHER delegates to the referenced type.

        Mismatches surface as CddlValidationError. The UNION branch relies on this to fall
        through to the next alternative.
        """
        self._check_key(obj)
        obj = self._check_tag(obj)
        self._check_type(obj)
        self._check_value(obj)
        if self.type in ["UINT", "INT", "NINT", "FLOAT", "TSTR",
                         "BSTR", "BOOL", "NIL", "UNDEF", "ANY"]:
            return obj
        elif self.type == "OTHER":
            return self.my_types[self.value]._decode_single_obj(obj)
        elif self.type == "LIST":
            retval = list()
            child_val = iter(obj)
            # Each child consumes as many list elements as it needs and hands back the
            # iterator for the next child.
            for child in self.value:
                ret = child._decode_full(child_val)
                child_val, child_obj = ret
                child._add_if(retval, child_obj)
            # Every list element must have been matched by some child.
            self._iter_is_empty(child_val)
            return self._construct_obj(retval)
        elif self.type == "MAP":
            retval = list()
            # Map entries are presented to the children as KeyTuples of (key, value).
            child_val = iter(KeyTuple(item) for item in obj.items())
            for child in self.value:
                child_val, child_key_val = child._decode_full(child_val)
                child._add_if(retval, child_key_val, expect_key=True)
            self._iter_is_empty(child_val)
            return self._construct_obj(retval)
        elif self.type == "UNION":
            retval = list()
            # Try each alternative in order; the first that validates wins and its name is
            # recorded under "union_choice". Failures are kept in self.errors for reporting.
            for child in self.value:
                try:
                    child_obj = child._decode_single_obj(obj)
                    child._add_if(retval, child_obj)
                    retval.append(("union_choice", child.var_name()))
                    return self._construct_obj(retval)
                except CddlValidationError as c:
                    self.errors.append(str(c))
            self._decode_assert(False, "No matches for union: " + str(self))
        assert False, "Unexpected type: " + self.type
1489
1490    # Decode key and value in the form of a KeyTuple
1491    def _handle_key(self, next_obj):
1492        self._decode_assert(
1493            isinstance(next_obj, KeyTuple), f"Expected key: {self.key} value=" + pformat(next_obj))
1494        key, obj = next_obj
1495        key_res = self.key._decode_single_obj(key)
1496        obj_res = self._decode_single_obj(obj)
1497        res = KeyTuple((key_res if not self.key.is_unambiguous() else None, obj_res))
1498        return res
1499
1500    # Decode single CDDL value, excluding repetitions. May consume 0 to n CBOR objects via the
1501    # iterator.
1502    def _decode_obj(self, it):
1503        my_list = list()
1504        if self.key is not None:
1505            it, it_copy = tee(it)
1506            key_res = self._handle_key(self._iter_next(it_copy))
1507            return it_copy, key_res
1508        if self.tags:
1509            it, it_copy = tee(it)
1510            maybe_tag = next(it_copy)
1511            if isinstance(maybe_tag, CBORTag):
1512                tag_res = self._decode_single_obj(maybe_tag)
1513                return it_copy, tag_res
1514        if self.type == "OTHER" and self.key is None:
1515            return self.my_types[self.value]._decode_full(it)
1516        elif self.type == "GROUP":
1517            my_list = list()
1518            child_it = it
1519            for child in self.value:
1520                child_it, child_obj = child._decode_full(child_it)
1521                if child.key is not None:
1522                    child._add_if(my_list, child_obj, expect_key=True)
1523                else:
1524                    child._add_if(my_list, child_obj)
1525            ret = (child_it, self._construct_obj(my_list))
1526        elif self.type == "UNION":
1527            my_list = list()
1528            child_it = it
1529            found = False
1530            for child in self.value:
1531                try:
1532                    child_it, it_copy = tee(child_it)
1533                    child_it, child_obj = child._decode_full(child_it)
1534                    child._add_if(my_list, child_obj)
1535                    my_list.append(("union_choice", child.var_name()))
1536                    ret = (child_it, self._construct_obj(my_list))
1537                    found = True
1538                    break
1539                except CddlValidationError as c:
1540                    self.errors.append(str(c))
1541                    child_it = it_copy
1542            self._decode_assert(found, "No matches for union: " + str(self))
1543        else:
1544            ret = (it, self._decode_single_obj(self._iter_next(it)))
1545        return ret
1546
1547    # Decode single CDDL value, with repetitions. May consume 0 to n CBOR objects via the iterator.
    def _decode_full(self, it):
        """Decode this element, including repetitions, from iterator it.

        Returns (it, value), where it is the iterator positioned after the consumed objects.

        When multi_var_condition() holds, value is a list with one entry per decoded
        occurrence. Each entry is None when is_unambiguous_repeated().

        The first min_qty occurrences are mandatory; failures there propagate. Up to
        max_qty - min_qty further occurrences are then attempted. The first failing attempt
        is recorded in self.errors and rolled back, and the loop stops.
        """
        if self.multi_var_condition():
            retvals = []
            for i in range(self.min_qty):
                it, retval = self._decode_obj(it)
                retvals.append(retval if not self.is_unambiguous_repeated() else None)
            try:
                for i in range(self.max_qty - self.min_qty):
                    # it_copy is an untouched tee of the position before this attempt.
                    it, it_copy = tee(it)
                    it, retval = self._decode_obj(it)
                    retvals.append(retval if not self.is_unambiguous_repeated() else None)
            except CddlValidationError as c:
                self.errors.append(str(c))
                # Roll back to before the failed optional occurrence.
                it = it_copy
            return it, retvals
        else:
            ret = self._decode_obj(it)
            return ret
1566
1567    # CBOR object => python object
1568    def decode_obj(self, obj):
1569        it = iter([obj])
1570        try:
1571            _, decoded = self._decode_full(it)
1572            self._iter_is_empty(it)
1573        except CddlValidationError as e:
1574            if self.errors:
1575                print("Errors:")
1576                pprint(self.errors)
1577            raise e
1578        return decoded
1579
1580    # YAML => python object
1581    def decode_str_yaml(self, yaml_str, yaml_compat=False):
1582        yaml_obj = yaml_load(yaml_str)
1583        obj = self._from_yaml_obj(yaml_obj) if yaml_compat else yaml_obj
1584        self.validate_obj(obj)
1585        return self.decode_obj(obj)
1586
1587    # CBOR bytestring => python object
1588    def decode_str(self, cbor_str):
1589        cbor_obj = loads(cbor_str)
1590        return self.decode_obj(cbor_obj)
1591
1592    # Validate CBOR object against CDDL. Exception if not valid.
1593    def validate_obj(self, obj):
1594        self.decode_obj(obj)
1595        return True
1596
1597    # Validate CBOR bytestring against CDDL. Exception if not valid.
1598    def validate_str(self, cbor_str):
1599        cbor_obj = loads(cbor_str)
1600        return self.validate_obj(cbor_obj)
1601
1602    # Convert object from YAML/JSON (with special dicts for bstr, tag etc) to CBOR object
1603    # that cbor2 understands.
1604    def _from_yaml_obj(self, obj):
1605        if isinstance(obj, list):
1606            if len(obj) == 1 and obj[0] == "zcbor_undefined":
1607                return undefined
1608            return [self._from_yaml_obj(elem) for elem in obj]
1609        elif isinstance(obj, dict):
1610            if ["zcbor_bstr"] == list(obj.keys()):
1611                if isinstance(obj["zcbor_bstr"], str):
1612                    bstr = bytes.fromhex(obj["zcbor_bstr"])
1613                else:
1614                    bstr = dumps(self._from_yaml_obj(obj["zcbor_bstr"]))
1615                return bstr
1616            elif ["zcbor_tag", "zcbor_tag_val"] == list(obj.keys()):
1617                return CBORTag(obj["zcbor_tag"], self._from_yaml_obj(obj["zcbor_tag_val"]))
1618            retval = dict()
1619            for key, val in obj.items():
1620                match = fullmatch(r"zcbor_keyval\d+", key)
1621                if match is not None:
1622                    new_key = self._from_yaml_obj(val["key"])
1623                    new_val = self._from_yaml_obj(val["val"])
1624                    if isinstance(new_key, list):
1625                        new_key = tuple(new_key)
1626                    retval[new_key] = new_val
1627                else:
1628                    retval[key] = self._from_yaml_obj(val)
1629            return retval
1630        return obj
1631
1632    # inverse of _from_yaml_obj
1633    def _to_yaml_obj(self, obj):
1634        if isinstance(obj, list) or isinstance(obj, tuple):
1635            return [self._to_yaml_obj(elem) for elem in obj]
1636        elif isinstance(obj, dict):
1637            retval = dict()
1638            i = 0
1639            for key, val in obj.items():
1640                if not isinstance(key, str):
1641                    retval[f"zcbor_keyval{i}"] = {
1642                        "key": self._to_yaml_obj(key), "val": self._to_yaml_obj(val)}
1643                    i += 1
1644                else:
1645                    retval[key] = self._to_yaml_obj(val)
1646            return retval
1647        elif isinstance(obj, bytes):
1648            f = BytesIO(obj)
1649            try:
1650                bstr_obj = self._to_yaml_obj(load(f))
1651            except (CBORDecodeValueError, CBORDecodeEOF):
1652                # failed decoding
1653                bstr_obj = obj.hex()
1654            else:
1655                if f.read(1) != b'':
1656                    # not fully decoded
1657                    bstr_obj = obj.hex()
1658            return {"zcbor_bstr": bstr_obj}
1659        elif isinstance(obj, CBORTag):
1660            return {"zcbor_tag": obj.tag, "zcbor_tag_val": self._to_yaml_obj(obj.value)}
1661        elif obj is undefined:
1662            return ["zcbor_undefined"]
1663        assert not isinstance(obj, bytes)
1664        return obj
1665
1666    # YAML str => CBOR bytestr
1667    def from_yaml(self, yaml_str, yaml_compat=False):
1668        yaml_obj = yaml_load(yaml_str)
1669        obj = self._from_yaml_obj(yaml_obj) if yaml_compat else yaml_obj
1670        self.validate_obj(obj)
1671        return dumps(obj)
1672
1673    # CBOR object => YAML str
1674    def obj_to_yaml(self, obj, yaml_compat=False):
1675        self.validate_obj(obj)
1676        yaml_obj = self._to_yaml_obj(obj) if yaml_compat else obj
1677        return yaml_dump(yaml_obj)
1678
1679    # CBOR bytestring => YAML str
1680    def str_to_yaml(self, cbor_str, yaml_compat=False):
1681        return self.obj_to_yaml(loads(cbor_str), yaml_compat=yaml_compat)
1682
1683    # JSON str => CBOR bytestr
1684    def from_json(self, json_str, yaml_compat=False):
1685        json_obj = json_load(json_str)
1686        obj = self._from_yaml_obj(json_obj) if yaml_compat else json_obj
1687        self.validate_obj(obj)
1688        return dumps(obj)
1689
1690    # CBOR object => JSON str
1691    def obj_to_json(self, obj, yaml_compat=False):
1692        self.validate_obj(obj)
1693        json_obj = self._to_yaml_obj(obj) if yaml_compat else obj
1694        return json_dump(json_obj)
1695
1696    # CBOR bytestring => JSON str
1697    def str_to_json(self, cbor_str, yaml_compat=False):
1698        return self.obj_to_json(loads(cbor_str), yaml_compat=yaml_compat)
1699
1700    # CBOR bytestring => C code (uint8_t array initialization)
1701    def str_to_c_code(self, cbor_str, var_name, columns=0):
1702        arr = ", ".join(f"0x{c:02x}" for c in cbor_str)
1703        if columns:
1704            arr = '\n' + indent("\n".join(wrap(arr, 6 * columns)), '\t') + '\n'
1705        return f'uint8_t {var_name}[] = {{{arr}}};\n'
1706
1707
class XcoderTuple(NamedTuple):
    """Bundle describing one generated encoder/decoder function.

    NOTE(review): users are outside this chunk. The field meanings below are inferred from
    the names and should be confirmed.
    """
    # Generated lines of the function body (presumably C source lines).
    body: list
    # Name of the generated encode/decode function.
    func_name: str
    # Name of the C type the function encodes/decodes.
    type_name: str
1712
1713
class CddlTypes(NamedTuple):
    """Result of parsing CDDL: the named types and the control groups."""
    # Maps type names to parsed type objects (CodeGenerator.from_cddl() iterates it and calls
    # set_access_prefix() on each value).
    my_types: dict
    # Control groups, presumably keyed by the name used in a .bits control (cf.
    # CodeGenerator.type_def(), which looks up my_control_groups[self.bits]).
    my_control_groups: dict
1717
1718
1719# Class for generating C code that encode/decodes CBOR and validates it according
1720# to the CDDL.
1721class CodeGenerator(CddlXcoder):
1722    def __init__(self, mode, entry_type_names, default_bit_size, *args, **kwargs):
1723        super(CodeGenerator, self).__init__(*args, **kwargs)
1724        self.mode = mode
1725        self.entry_type_names = entry_type_names
1726        self.default_bit_size = default_bit_size
1727
1728    @classmethod
1729    def from_cddl(cddl_class, mode, *args, **kwargs):
1730        cddl_res = super(CodeGenerator, cddl_class).from_cddl(*args, **kwargs)
1731
1732        # set access prefix (struct access paths) for all the definitions.
1733        for my_type in cddl_res.my_types:
1734            cddl_res.my_types[my_type].set_access_prefix(f"(*{struct_ptr_name(mode)})")
1735
1736        return cddl_res
1737
1738    # Whether this element (an OTHER) refers to an entry type.
1739    def is_entry_type(self):
1740        return (self.type == "OTHER") and (self.value in self.entry_type_names)
1741
1742    # Whether to include a "cbor" variable for this element.
1743    def is_cbor(self):
1744        res = (self.type_name() is not None) and not self.is_entry_type() and (
1745            (self.type != "OTHER") or self.my_types[self.value].is_cbor())
1746        return res
1747
1748    def init_args(self):
1749        return (self.mode, self.entry_type_names, self.default_bit_size, self.default_max_qty)
1750
1751    # Declaration of the "present" variable for this element.
1752    def present_var(self):
1753        return ["bool %s;" % self.present_var_name()]
1754
1755    # Declaration of the "count" variable for this element.
1756    def count_var(self):
1757        return ["size_t %s;" % self.count_var_name()]
1758
    # Declaration of an anonymous enum listing this element's choices (without a variable name).
1760    def anonymous_choice_var(self):
1761        var = self.enclose(
1762            "enum", [val.enum_var_name(self.all_children_int_disambiguated())
1763                     + "," for val in self.value])
1764        return var
1765
1766    # Declaration of the "choice" variable for this element.
1767    def choice_var(self):
1768        var = self.anonymous_choice_var()
1769        var[-1] += f" {self.choice_var_name()};"
1770        return var
1771
1772    # Declaration of the variables of all children.
1773    def child_declarations(self):
1774        decl = [line for child in self.value for line in child.full_declaration()]
1775        return decl
1776
    # Anonymous declarations of one variable per child (skipping unambiguous repeated ones), used for unions.
1778    def child_single_declarations(self):
1779        decl = list()
1780        for child in self.value:
1781            if not child.is_unambiguous_repeated():
1782                decl.extend(child.add_var_name(
1783                    child.single_var_type(), anonymous=True))
1784        return decl
1785
1786    def simple_func_condition(self):
1787        if self.single_func_impl_condition():
1788            return True
1789        if self.type == "OTHER" and self.my_types[self.value].simple_func_condition():
1790            return True
1791        return False
1792
1793    # Base name if this element needs to declare a type.
1794    def raw_type_name(self):
1795        return "struct %s" % self.id()
1796
1797    def enum_type_name(self):
1798        return "enum %s" % self.id()
1799
1800    def convert_min_max(self):
1801        defines = {
1802            (2**64) - 1: "UINT64_MAX",
1803            (2**63) - 1: "INT64_MAX",
1804            (2**32) - 1: "UINT32_MAX",
1805            (2**31) - 1: "INT32_MAX",
1806            (2**16) - 1: "UINT16_MAX",
1807            (2**15) - 1: "INT16_MAX",
1808            (2**8) - 1: "UINT8_MAX",
1809            (2**7) - 1: "INT8_MAX",
1810            -(2**63): "INT64_MIN",
1811            -(2**31): "INT32_MIN",
1812            -(2**15): "INT16_MIN",
1813            -(2**7): "INT8_MIN",
1814        }
1815        if self.value in defines:
1816            self.value = defines[self.value]
1817        if self.min_value in defines:
1818            self.min_value = defines[self.min_value]
1819        if self.max_value in defines:
1820            self.max_value = defines[self.max_value]
1821        if self.min_size in defines:
1822            self.min_size = defines[self.min_size]
1823        if self.max_size in defines:
1824            self.max_size = defines[self.max_size]
1825
1826    # The bit width of the integers as represented in code.
1827    def bit_size(self):
1828        bit_size = None
1829        if self.type in ["UINT", "INT", "NINT"]:
1830            assert self.default_bit_size in [32, 64], "The default_bit_size must be 32 or 64."
1831            if self.default_bit_size == 64:
1832                bit_size = 64
1833            else:
1834                bit_size = 32
1835
1836                for v in [self.value or 0, self.max_value or 0, self.min_value or 0]:
1837                    if (type(v) is str):
1838                        if "64" in v:
1839                            bit_size = 64
1840                    elif self.type == "UINT":
1841                        if (v > UINT32_MAX):
1842                            bit_size = 64
1843                    else:
1844                        if (v > INT32_MAX) or (v < INT32_MIN):
1845                            bit_size = 64
1846        return bit_size
1847
1848    def float_type(self):
1849        if self.type != "FLOAT":
1850            return None
1851
1852        max_size = self.max_size or 8
1853
1854        if max_size <= 4:
1855            return "float"
1856        elif max_size == 8:
1857            return "double"
1858        else:
1859            raise TypeError("Floats must have 4 or 8 bytes of precision.")
1860
1861    # Name of the type of this element's actual value variable.
1862    def val_type_name(self):
1863        if self.multi_val_condition():
1864            return self.raw_type_name()
1865
1866        # Will fail runtime if we don't use lambda for type_name()
1867        # pylint: disable=unnecessary-lambda
1868        name = {
1869            "INT": lambda: f"int{self.bit_size()}_t",
1870            "UINT": lambda: f"uint{self.bit_size()}_t",
1871            "NINT": lambda: f"int{self.bit_size()}_t",
1872            "FLOAT": lambda: self.float_type(),
1873            "BSTR": lambda: "struct zcbor_string",
1874            "TSTR": lambda: "struct zcbor_string",
1875            "BOOL": lambda: "bool",
1876            "NIL": lambda: None,
1877            "UNDEF": lambda: None,
1878            "ANY": lambda: None,
1879            "LIST": lambda: self.value[0].type_name() if len(self.value) >= 1 else None,
1880            "MAP": lambda: self.value[0].type_name() if len(self.value) >= 1 else None,
1881            "GROUP": lambda: self.value[0].type_name() if len(self.value) >= 1 else None,
1882            "UNION": lambda: self.union_type(),
1883            "OTHER": lambda: self.my_types[self.value].type_name(),
1884        }[self.type]()
1885
1886        return name
1887
    # Name of the type for the repeated part of this element.
1889    def repeated_type_name(self):
1890        if self.self_repeated_multi_var_condition():
1891            name = self.raw_type_name()
1892            if self.val_type_name() == name:
1893                name = name + "_"
1894        else:
1895            name = self.val_type_name()
1896        return name
1897
1898    # Name of the type for this element.
1899    def type_name(self):
1900        if self.multi_var_condition():
1901            name = self.raw_type_name()
1902            if self.val_type_name() == name:
1903                name = name + "_"
1904            if self.repeated_type_name() == name:
1905                name = name + "_"
1906        else:
1907            name = self.repeated_type_name()
1908        return name
1909
1910    # Take a multi member type name and create a variable declaration. Make it an array if the
1911    # element is repeated.
1912    def add_var_name(self, var_type, full=False, anonymous=False):
1913        if var_type:
1914            assert (var_type[-1][-1] == "}" or len(var_type) == 1), \
1915                f"Expected single var: {var_type!r}"
1916            if not anonymous or var_type[-1][-1] != "}":
1917                var_type[-1] += " %s%s" % (
1918                    self.var_name(), "[%s]" % self.max_qty if full and self.max_qty != 1 else "")
1919            var_type = add_semicolon(var_type)
1920        return var_type
1921
1922    # The type for this element as a member variable.
1923    def var_type(self):
1924        if not self.multi_val_condition() and self.val_type_name() is not None:
1925            return [self.val_type_name()]
1926        elif self.type == "UNION":
1927            return self.union_type()
1928        return []
1929
1930    # Enclose a list of declarations in a block (struct, union or enum).
1931    def enclose(self, ingress, declaration):
1932        if declaration:
1933            return [f"{ingress} {{"] + [indentation + line for line in declaration] + ["}"]
1934        else:
1935            return []
1936
1937    # Type declaration for unions.
1938    def union_type(self):
1939        declaration = self.enclose("union", self.child_single_declarations())
1940        return declaration
1941
1942    # Declaration of the repeated part of this element.
1943    def repeated_declaration(self):
1944        if self.is_unambiguous_repeated():
1945            return []
1946
1947        var_type = self.var_type()
1948        multi_var = False
1949
1950        decl = self.add_var_name(var_type, anonymous=(self.type == "UNION"))
1951
1952        if self.type in ["LIST", "MAP", "GROUP"]:
1953            decl += self.child_declarations()
1954            multi_var = len(decl) > 1
1955
1956        if self.reduced_key_var_condition():
1957            key_var = self.key.full_declaration()
1958            decl = key_var + decl
1959            multi_var = key_var != []
1960
1961        if self.choice_var_condition():
1962            choice_var = self.choice_var()
1963            decl += choice_var
1964            multi_var = choice_var != []
1965
1966        if self.cbor_var_condition():
1967            cbor_var = self.cbor.full_declaration()
1968            decl += cbor_var
1969            multi_var = cbor_var != []
1970
1971        # if self.type not in ["LIST", "MAP", "GROUP"] or len(self.value) <= 1:
1972            # This assert isn't accurate for value lists with NIL, "UNDEF", or ANY
1973            # members.
1974            # assert multi_var == self.repeated_multi_var_condition(
1975            # ), f"""assert {multi_var} == {self.repeated_multi_var_condition()}
1976            # type: {self.type}
1977            # decl {decl}
1978            # self.key_var_condition() is {self.key_var_condition()}
1979            # self.key is {self.key}
1980            # self.key.full_declaration() is {self.key.full_declaration()}
1981            # self.cbor_var_condition() is {self.cbor_var_condition()}
1982            # self.choice_var_condition() is {self.choice_var_condition()}
1983            # self.value is {self.value}"""
1984        return decl
1985
1986    # Declaration of the full type for this element.
    def full_declaration(self):
        """Return the C member declaration lines for this element, including repetitions.

        Returns [] for unambiguous elements. When multi_var_condition() holds, the repeated
        part is declared by type name via add_var_name(full=True); it becomes an array when
        max_qty != 1. Otherwise repeated_declaration() is used. The count and/or present
        variables follow when their conditions hold.
        """
        multi_var = False

        if self.is_unambiguous():
            return []

        if self.multi_var_condition():
            if self.is_unambiguous_repeated():
                decl = []
            else:
                decl = self.add_var_name(
                    [self.repeated_type_name()]
                    if self.repeated_type_name() is not None else [], full=True)
        else:
            decl = self.repeated_declaration()

        if self.count_var_condition():
            count_var = self.count_var()
            decl += count_var
            multi_var = count_var != []

        if self.present_var_condition():
            present_var = self.present_var()
            decl += present_var
            multi_var = present_var != []

        # Sanity check: a count or present variable is emitted exactly when
        # multi_var_condition() holds.
        assert multi_var == self.multi_var_condition()

        return decl
2016
2017    # Return the type definition of this element. If there are multiple variables, wrap them in a
2018    # struct so the function always returns a single type with no name. If full is False, only
2019    # repeated part is used.
2020    def single_var_type(self, full=True):
2021        if full and self.multi_member():
2022            return self.enclose("struct", self.full_declaration())
2023        elif not full and self.repeated_multi_var_condition():
2024            return self.enclose("struct", self.repeated_declaration())
2025        else:
2026            return self.var_type()
2027
2028    # Return the type definition of this element, and all its children + key +
2029    # cbor.
2030    def type_def(self):
2031        ret_val = []
2032        if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
2033            ret_val.extend(
2034                [elem for typedef in [
2035                    child.type_def() for child in self.value] for elem in typedef])
2036        if self.bits:
2037            ret_val.extend(self.my_control_groups[self.bits].type_def_bits())
2038        if self.cbor_var_condition():
2039            ret_val.extend(self.cbor.type_def())
2040        if self.reduced_key_var_condition():
2041            ret_val.extend(self.key.type_def())
2042        if self.type == "OTHER":
2043            ret_val.extend(self.my_types[self.value].type_def())
2044        if self.repeated_type_def_condition():
2045            type_def_list = self.single_var_type(full=False)
2046            if type_def_list:
2047                ret_val.extend([(self.single_var_type(full=False), self.repeated_type_name())])
2048        if self.type_def_condition():
2049            type_def_list = self.single_var_type()
2050            if type_def_list:
2051                ret_val.extend([(self.single_var_type(), self.type_name())])
2052        return ret_val
2053
2054    def type_def_bits(self):
2055        tdef = self.anonymous_choice_var()
2056        return [(tdef, self.enum_type_name())]
2057
2058    def float_prefix(self):
2059        if self.type != "FLOAT":
2060            return ""
2061
2062        min_size = self.min_size or 2
2063        max_size = self.max_size or 8
2064
2065        if max_size == 2:
2066            return "float16"
2067        elif min_size == 2 and max_size == 4:
2068            return "float16_32" if self.mode == "decode" else "float32"
2069        if min_size == 4 and max_size == 4:
2070            return "float32"
2071        elif min_size == 4 and max_size == 8:
2072            return "float32_64" if self.mode == "decode" else "float64"
2073        elif min_size == 8 and max_size == 8:
2074            return "float64"
2075        elif min_size <= 4 and max_size == 8:
2076            return "float" if self.mode == "decode" else "float64"
2077        else:
2078            raise TypeError("Floats must have 2, 4 or 8 bytes of precision.")
2079
2080    def single_func_prim_prefix(self):
2081        if self.type == "OTHER":
2082            return self.my_types[self.value].single_func_prim_prefix()
2083        return ({
2084            "INT": f"zcbor_int{self.bit_size()}",
2085            "UINT": f"zcbor_uint{self.bit_size()}",
2086            "NINT": f"zcbor_int{self.bit_size()}",
2087            "FLOAT": f"zcbor_{self.float_prefix()}",
2088            "BSTR": f"zcbor_bstr",
2089            "TSTR": f"zcbor_tstr",
2090            "BOOL": f"zcbor_bool",
2091            "NIL": f"zcbor_nil",
2092            "UNDEF": f"zcbor_undefined",
2093            "ANY": f"zcbor_any",
2094        }[self.type])
2095
2096    # Name of the encoder/decoder function for this element.
2097    def xcode_func_name(self):
2098        return f"{self.mode}{self.var_name(with_prefix=True)}"
2099
2100    # Name of the encoder/decoder function for the repeated part of this element.
2101    def repeated_xcode_func_name(self):
2102        return f"{self.mode}_repeated{self.var_name(with_prefix=True)}"
2103
2104    def single_func_prim_name(self, union_int=None):
2105        """Function name for xcoding this type, when it is a primitive type"""
2106        func_prefix = self.single_func_prim_prefix()
2107        if self.mode == "decode":
2108            if self.type == "ANY":
2109                func = "zcbor_any_skip"
2110            elif not self.is_unambiguous_value():
2111                func = f"{func_prefix}_decode"
2112            elif not union_int:
2113                func = f"{func_prefix}_expect"
2114            elif union_int == "EXPECT":
2115                func = f"{func_prefix}_expect_union"
2116            elif union_int == "DROP":
2117                return None
2118        else:
2119            if self.type == "ANY":
2120                func = "zcbor_nil_put"
2121            elif (not self.is_unambiguous_value()) or self.type in ["TSTR", "BSTR"]:
2122                func = f"{func_prefix}_encode"
2123            else:
2124                func = f"{func_prefix}_put"
2125        return func
2126
2127    def single_func_prim(self, access, union_int=None, ptr_result=False):
2128        """Return the function name and arguments to call to encode/decode this element. Only used
2129        when this element DOESN'T define its own encoder/decoder function (when it's a primitive
2130        type, for which functions already exist, or when the function is defined elsewhere
2131        ("OTHER"))
2132        """
2133        assert self.type not in ["LIST", "MAP"], "Must have wrapper function for list or map."
2134
2135        if self.type == "OTHER":
2136            return self.my_types[self.value].single_func(access, union_int)
2137
2138        func_name = self.single_func_prim_name(union_int)
2139        if func_name is None:
2140            return (None, None)
2141
2142        if self.type in ["NIL", "UNDEF", "ANY"]:
2143            arg = "NULL"
2144        elif not self.is_unambiguous_value():
2145            arg = deref_if_not_null(access)
2146        elif self.type in ["BSTR", "TSTR"]:
2147            arg = tmp_str_or_null(self.value)
2148        elif self.type in ["UINT", "INT", "NINT", "FLOAT", "BOOL"]:
2149            # Make False and True lower case
2150            value = str(self.value).lower() if self.type == "BOOL" else str(self.value)
2151            arg = ("(void *) " if ptr_result else "") + value
2152        else:
2153            assert False, "Should not come here."
2154
2155        min_val = None
2156        max_val = None
2157
2158        if self.value is None:
2159            if self.type in ["BSTR", "TSTR"]:
2160                min_val = self.min_size
2161                max_val = self.max_size
2162            else:
2163                min_val = self.min_value
2164                max_val = self.max_value
2165        return (func_name, arg)
2166
2167    # Return the function name and arguments to call to encode/decode this element.
2168    def single_func(self, access=None, union_int=None):
2169        if self.single_func_impl_condition():
2170            return (self.xcode_func_name(), deref_if_not_null(access or self.var_access()))
2171        else:
2172            return self.single_func_prim(access or self.val_access(), union_int)
2173
2174    # Return the function name and arguments to call to encode/decode the repeated
2175    # part of this element.
2176    def repeated_single_func(self, ptr_result=False):
2177        if self.repeated_single_func_impl_condition():
2178            return (self.repeated_xcode_func_name(), deref_if_not_null(self.repeated_val_access()))
2179        else:
2180            return self.single_func_prim(self.repeated_val_access(), ptr_result=ptr_result)
2181
2182    def has_backup(self):
2183        return (self.cbor_var_condition() or self.type in ["LIST", "MAP", "UNION"])
2184
2185    def num_backups(self):
2186        total = 0
2187        if self.key:
2188            total += self.key.num_backups()
2189        if self.cbor_var_condition():
2190            total += self.cbor.num_backups()
2191        if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
2192            total += max([child.num_backups() for child in self.value] + [0])
2193        if self.type == "OTHER":
2194            total += self.my_types[self.value].num_backups()
2195        if self.has_backup():
2196            total += 1
2197        return total
2198
2199    # Return a number indicating how many other elements this element depends on. Used putting
2200    # functions and typedefs in the right order.
2201    def depends_on(self):
2202        ret_vals = [1]
2203
2204        if not self.dependsOnCall:
2205            self.dependsOnCall = True
2206            if self.cbor_var_condition():
2207                ret_vals.append(self.cbor.depends_on())
2208            if self.key:
2209                ret_vals.append(self.key.depends_on())
2210            if self.type == "OTHER":
2211                ret_vals.append(1 + self.my_types[self.value].depends_on())
2212            if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
2213                ret_vals.extend(child.depends_on() for child in self.value)
2214            self.dependsOnCall = False
2215
2216        return max(ret_vals)
2217
2218    # Make a string from the list returned by single_func_prim()
2219    def xcode_single_func_prim(self, union_int=None):
2220        return xcode_statement(*self.single_func_prim(self.val_access(), union_int))
2221
2222    # Recursively sum the total minimum and maximum element count for this
2223    # element.
2224    def list_counts(self):
2225        retval = ({
2226            "INT": lambda: (self.min_qty, self.max_qty),
2227            "UINT": lambda: (self.min_qty, self.max_qty),
2228            "NINT": lambda: (self.min_qty, self.max_qty),
2229            "FLOAT": lambda: (self.min_qty, self.max_qty),
2230            "BSTR": lambda: (self.min_qty, self.max_qty),
2231            "TSTR": lambda: (self.min_qty, self.max_qty),
2232            "BOOL": lambda: (self.min_qty, self.max_qty),
2233            "NIL": lambda: (self.min_qty, self.max_qty),
2234            "UNDEF": lambda: (self.min_qty, self.max_qty),
2235            "ANY": lambda: (self.min_qty, self.max_qty),
2236            # Lists are their own element
2237            "LIST": lambda: (self.min_qty, self.max_qty),
2238            # Maps are their own element
2239            "MAP": lambda: (self.min_qty, self.max_qty),
2240            "GROUP": lambda: (self.min_qty * sum((child.list_counts()[0] for child in self.value)),
2241                              self.max_qty * sum((child.list_counts()[1] for child in self.value))),
2242            "UNION": lambda: (self.min_qty * min((child.list_counts()[0] for child in self.value)),
2243                              self.max_qty * max((child.list_counts()[1] for child in self.value))),
2244            "OTHER": lambda: (self.min_qty * self.my_types[self.value].list_counts()[0],
2245                              self.max_qty * self.my_types[self.value].list_counts()[1]),
2246        }[self.type]())
2247        return retval
2248
2249    # Return the full code needed to encode/decode a "LIST" or "MAP" element with children.
2250    def xcode_list(self):
2251        start_func = f"zcbor_{self.type.lower()}_start_{self.mode}"
2252        end_func = f"zcbor_{self.type.lower()}_end_{self.mode}"
2253        end_func_force = f"zcbor_list_map_end_force_{self.mode}"
2254        assert start_func in [
2255            "zcbor_list_start_decode", "zcbor_list_start_encode",
2256            "zcbor_map_start_decode", "zcbor_map_start_encode"]
2257        assert end_func in [
2258            "zcbor_list_end_decode", "zcbor_list_end_encode",
2259            "zcbor_map_end_decode", "zcbor_map_end_encode"]
2260        assert self.type in ["LIST", "MAP"], \
2261            "Expected LIST or MAP type, was %s." % self.type
2262        _, max_counts = zip(
2263            *(child.list_counts() for child in self.value)) if self.value else ((0,), (0,))
2264        count_arg = f', {str(sum(max_counts))}' if self.mode == 'encode' else ''
2265        with_children = "(%s && ((%s) || (%s, false)) && %s)" % (
2266            f"{start_func}(state{count_arg})",
2267            f"{newl_ind}&& ".join(child.full_xcode() for child in self.value),
2268            f"{end_func_force}(state)",
2269            f"{end_func}(state{count_arg})")
2270        without_children = "(%s && %s)" % (
2271            f"{start_func}(state{count_arg})",
2272            f"{end_func}(state{count_arg})")
2273        return with_children if len(self.value) > 0 else without_children
2274
2275    # Return the full code needed to encode/decode a "GROUP" element's children.
2276    def xcode_group(self, union_int=None):
2277        assert self.type in ["GROUP"], "Expected GROUP type."
2278        return "(%s)" % (newl_ind + "&& ").join(
2279            [self.value[0].full_xcode(union_int)]
2280            + [child.full_xcode() for child in self.value[1:]])
2281
2282    # Return the full code needed to encode/decode a "UNION" element's children.
2283    def xcode_union(self):
2284        assert self.type in ["UNION"], "Expected UNION type."
2285        if self.mode == "decode":
2286            if self.all_children_int_disambiguated():
2287                lines = []
2288                lines.extend(
2289                    ["((%s == %s) && (%s))" %
2290                        (self.choice_var_access(), child.enum_var_name(),
2291                            child.full_xcode(union_int="DROP"))
2292                        for child in self.value])
2293                bit_size = self.value[0].bit_size()
2294                func = f"zcbor_int_{self.mode}"
2295                return "((%s) && (%s))" % (
2296                    f"({func}(state, &{self.choice_var_access()}, "
2297                    + f"sizeof({self.choice_var_access()})))",
2298                    "((" + f"{newl_ind}|| ".join(lines)
2299                         + ") || (zcbor_error(state, ZCBOR_ERR_WRONG_VALUE), false))",)
2300
2301            child_values = ["(%s && ((%s = %s), true))" %
2302                            (child.full_xcode(
2303                                union_int="EXPECT" if child.is_int_disambiguated() else None),
2304                                self.choice_var_access(), child.enum_var_name())
2305                            for child in self.value]
2306
2307            # Reset state for all but the first child.
2308            for i in range(1, len(child_values)):
2309                if ((not self.value[i].is_int_disambiguated())
2310                        and (self.value[i].simple_func_condition()
2311                             or self.value[i - 1].simple_func_condition())):
2312                    child_values[i] = f"(zcbor_union_elem_code(state) && {child_values[i]})"
2313
2314            return "(%s && (int_res = (%s), %s, int_res))" \
2315                % ("zcbor_union_start_code(state)",
2316                   f"{newl_ind}|| ".join(child_values),
2317                   "zcbor_union_end_code(state)")
2318        else:
2319            return ternary_if_chain(
2320                self.choice_var_access(),
2321                [child.enum_var_name() for child in self.value],
2322                [child.full_xcode() for child in self.value])
2323
2324    def xcode_bstr(self):
2325        if self.cbor and not self.cbor.is_entry_type():
2326            access_arg = f', {deref_if_not_null(self.val_access())}' if self.mode == 'decode' \
2327                else ''
2328            res_arg = f', &tmp_str' if self.mode == 'encode' \
2329                else ''
2330            xcode_cbor = "(%s)" % ((newl_ind + "&& ").join(
2331                [f"zcbor_bstr_start_{self.mode}(state{access_arg})",
2332                 f"(int_res = ({self.cbor.full_xcode()}), "
2333                 f"zcbor_bstr_end_{self.mode}(state{res_arg}), int_res)"]))
2334            if self.mode == "decode" or self.is_unambiguous():
2335                return xcode_cbor
2336            else:
2337                return f"({self.val_access()}.value " \
2338                    f"? (memcpy(&tmp_str, &{self.val_access()}, sizeof(tmp_str)), " \
2339                    f"{self.xcode_single_func_prim()}) : ({xcode_cbor}))"
2340        return self.xcode_single_func_prim()
2341
2342    def xcode_tags(self):
2343        return [f"zcbor_tag_{self.mode if (self.mode == 'encode') else 'expect'}(state, {tag})"
2344                for tag in self.tags]
2345
2346    # Appends ULL or LL if a value exceeding 32-bits is used
2347    def value_suffix(self, value):
2348        if type(value) is str:
2349            return ""
2350        if self.type == "INT" or self.type == "NINT":
2351            if value > INT32_MAX or value <= INT32_MIN:
2352                return "LL"
2353        elif self.type == "UINT":
2354            if value > UINT32_MAX:
2355                return "ULL"
2356
2357        return ""
2358
2359    def range_checks(self, access):
2360        if self.type != "OTHER" and self.value is not None:
2361            return []
2362
2363        range_checks = []
2364
2365        if self.type in ["INT", "UINT", "NINT", "FLOAT", "BOOL"]:
2366            if self.min_value is not None:
2367                range_checks.append(f"({access} >= {self.min_value}"
2368                                    f"{self.value_suffix(self.min_value)})")
2369            if self.max_value is not None:
2370                range_checks.append(f"({access} <= {self.max_value}"
2371                                    f"{self.value_suffix(self.max_value)})")
2372            if self.bits:
2373                range_checks.append(
2374                    f"!({access} & ~("
2375                    + ' | '.join([f'(1 << {c.enum_var_name()})'
2376                                 for c in self.my_control_groups[self.bits].value])
2377                    + "))")
2378        elif self.type in ["BSTR", "TSTR"]:
2379            if self.min_size is not None:
2380                range_checks.append(f"({access}.len >= {self.min_size})")
2381            if self.max_size is not None:
2382                range_checks.append(f"({access}.len <= {self.max_size})")
2383        elif self.type == "OTHER":
2384            range_checks.extend(self.my_types[self.value].range_checks(access))
2385
2386        if range_checks:
2387            range_checks[0] = "((" + range_checks[0]
2388            range_checks[-1] = range_checks[-1] \
2389                + ") || (zcbor_error(state, ZCBOR_ERR_WRONG_RANGE), false))"
2390
2391        return range_checks
2392
2393    # Return the full code needed to encode/decode this element, including children,
2394    # key and cbor, excluding repetitions.
2395    def repeated_xcode(self, union_int=None):
2396        range_checks = self.range_checks(self.val_access())
2397        xcoder = {
2398            "INT": self.xcode_single_func_prim,
2399            "UINT": lambda: self.xcode_single_func_prim(union_int),
2400            "NINT": lambda: self.xcode_single_func_prim(union_int),
2401            "FLOAT": self.xcode_single_func_prim,
2402            "BSTR": self.xcode_bstr,
2403            "TSTR": self.xcode_single_func_prim,
2404            "BOOL": self.xcode_single_func_prim,
2405            "NIL": self.xcode_single_func_prim,
2406            "UNDEF": self.xcode_single_func_prim,
2407            "ANY": self.xcode_single_func_prim,
2408            "LIST": self.xcode_list,
2409            "MAP": self.xcode_list,
2410            "GROUP": lambda: self.xcode_group(union_int),
2411            "UNION": self.xcode_union,
2412            "OTHER": lambda: self.xcode_single_func_prim(union_int),
2413        }[self.type]
2414        xcoders = []
2415        if self.key:
2416            xcoders.append(self.key.full_xcode(union_int))
2417        if self.tags:
2418            xcoders.extend(self.xcode_tags())
2419        if self.mode == "decode":
2420            xcoders.append(xcoder())
2421            xcoders.extend(range_checks)
2422        elif self.type == "BSTR" and self.cbor:
2423            xcoders.append(xcoder())
2424            xcoders.extend(self.range_checks("tmp_str"))
2425        else:
2426            xcoders.extend(range_checks)
2427            xcoders.append(xcoder())
2428
2429        return "(%s)" % ((newl_ind + "&& ").join(xcoders),)
2430
2431    # Code for the size of the repeated part of this element.
2432    def result_len(self):
2433        if self.repeated_type_name() is None or self.is_unambiguous_repeated():
2434            return "0"
2435        else:
2436            return "sizeof(%s)" % self.repeated_type_name()
2437
2438    # Return the full code needed to encode/decode this element, including children,
2439    # key, cbor, and repetitions.
2440    def full_xcode(self, union_int=None):
2441        if self.present_var_condition():
2442            func, *arguments = self.repeated_single_func(ptr_result=True)
2443            return (
2444                f"zcbor_present_{self.mode}(&(%s), (zcbor_{self.mode}r_t *)%s, %s)" %
2445                (self.present_var_access(),
2446                 func,
2447                 xcode_args(*arguments),))
2448        elif self.count_var_condition():
2449            func, *arguments = self.repeated_single_func(ptr_result=True)
2450            minmax = "_minmax" if self.mode == "encode" else ""
2451            mode = self.mode
2452            return (
2453                f"zcbor_multi_{mode}{minmax}(%s, %s, &%s, (zcbor_{mode}r_t *)%s, %s, %s)" %
2454                (self.min_qty,
2455                 self.max_qty,
2456                 self.count_var_access(),
2457                 func,
2458                 xcode_args(*arguments),
2459                 self.result_len()))
2460        else:
2461            return self.repeated_xcode(union_int)
2462
2463    # Return the body of the encoder/decoder function for this element.
2464    def xcode(self):
2465        return self.full_xcode()
2466
2467    # Recursively return a list of the bodies of the encoder/decoder functions for
2468    # this element and its children + key + cbor.
2469    def xcoders(self):
2470        if self.type in ["LIST", "MAP", "GROUP", "UNION"]:
2471            for child in self.value:
2472                for xcoder in child.xcoders():
2473                    yield xcoder
2474        if self.cbor:
2475            for xcoder in self.cbor.xcoders():
2476                yield xcoder
2477        if self.key:
2478            for xcoder in self.key.xcoders():
2479                yield xcoder
2480        if self.type == "OTHER" and self.value not in self.entry_type_names:
2481            for xcoder in self.my_types[self.value].xcoders():
2482                yield xcoder
2483        if self.repeated_single_func_impl_condition():
2484            yield XcoderTuple(
2485                self.repeated_xcode(), self.repeated_xcode_func_name(), self.repeated_type_name())
2486        if (self.single_func_impl_condition()):
2487            xcode_body = self.xcode()
2488            yield XcoderTuple(xcode_body, self.xcode_func_name(), self.type_name())
2489
    def public_xcode_func_sig(self) -> str:
        """Return the C signature of this element's public entry function, without a trailing ';'.

        The returned string starts with a newline. The function is named
        ``cbor_<xcode_func_name()>`` and takes four parameters:
        - the payload buffer (``const`` when decoding);
        - ``payload_len``;
        - a pointer to this element's type (``const`` when encoding). The pointer is ``void *``
          when struct_ptr_name() does not occur in full_xcode();
        - ``payload_len_out``.
        """
        type_name = self.type_name()
        return f"""
int cbor_{self.xcode_func_name()}(
		{"const " if self.mode == "decode" else ""}uint8_t *payload, size_t payload_len,
		{"" if self.mode == "decode" else "const "}{type_name
            if struct_ptr_name(self.mode) in self.full_xcode() else "void"} *{
            struct_ptr_name(self.mode)},
		{"size_t *payload_len_out"})"""
2499
2500
class CodeRenderer():
    """Render the generated C source, C headers, shared type header and optional CMake file.

    Functions and typedefs are collected per mode ("decode"/"encode") from the entry types at
    construction time. The render_* methods return file contents as strings; render() writes
    them to the given file objects.
    """

    def __init__(self, entry_types, modes, print_time, default_max_qty, git_sha='', file_header=''):
        """
        :param entry_types: Dict mapping each mode to its list of entry-type elements.
        :param modes: The modes to generate code for.
        :param print_time: Include a generation timestamp in the file header.
        :param default_max_qty: Value of --default-max-qty, written into the generated files.
        :param git_sha: Appended to the zcbor version in the file header, if given.
        :param file_header: Custom text placed before the generated file header, if non-blank.
        """
        self.entry_types = entry_types
        self.print_time = print_time
        self.default_max_qty = default_max_qty

        self.sorted_types = dict()
        self.functions = dict()
        self.type_defs = dict()

        # Sort type definitions so the typedefs will come in the correct order in the header file
        # and the function in the correct order in the c file.
        for mode in modes:
            self.sorted_types[mode] = sorted(
                self.entry_types[mode], key=lambda _type: _type.depends_on(), reverse=False)

            # used_funcs() filters the list stored by unique_funcs(), so the order matters.
            self.functions[mode] = self.unique_funcs(mode)
            self.functions[mode] = self.used_funcs(mode)
            self.type_defs[mode] = self.unique_types(mode)

        self.version = __version__

        if git_sha:
            self.version += f'-{git_sha}'

        self.file_header = file_header.strip() + "\n\n" if file_header.strip() else ""
        self.file_header += f"""Generated using zcbor version {self.version}
https://github.com/NordicSemiconductor/zcbor{'''
at: ''' + datetime.now().strftime('%Y-%m-%d %H:%M:%S') if self.print_time else ''}
Generated with a --default-max-qty of {self.default_max_qty}"""

    def header_guard(self, file_name):
        """Return an include-guard macro name derived from the file's base name."""
        return path.basename(file_name).replace(".", "_").replace("-", "_").upper() + "__"

    def unique_types(self, mode):
        """Return a list of typedefs for all types of ``mode``, with duplicates removed.

        Each item from type_def() is indexed as (lines, type_name). A type name that appears
        twice must have identical lines; otherwise an AssertionError names the clashing type.
        """
        type_names = {}
        out_types = []
        for mtype in self.sorted_types[mode]:
            for type_def in mtype.type_def():
                type_name = type_def[1]
                if type_name not in type_names:
                    type_names[type_name] = type_def[0]
                    out_types.append(type_def)
                else:
                    # Implicit literal concatenation makes the whole message the left operand
                    # of '%', so both placeholders are filled.
                    assert ''.join(type_names[type_name]) == ''.join(type_def[0]), (
                        "Two elements share the type name %s, but their implementations are not "
                        "identical. Please change one or both names. One of them is %s"
                        % (type_name, pformat(mtype.type_def())))
        return out_types

    def unique_funcs(self, mode):
        """Return a list of encoder/decoder functions for all types of ``mode``, without duplicates.

        Items are (body, func_name, type_name) tuples from each type's xcoders(). A function
        name that appears twice must have an identical body; otherwise an AssertionError shows
        both bodies.
        """
        func_names = {}
        out_types = []
        for mtype in self.sorted_types[mode]:
            for func_type in list(mtype.xcoders()):
                func_xcode = func_type[0]
                func_name = func_type[1]
                if func_name not in func_names:
                    func_names[func_name] = func_type
                    out_types.append(func_type)
                else:
                    assert func_names[func_name][0] == func_xcode, \
                        ("Two elements share the function name %s, but their implementations are "
                            + "not identical. Please change one or both names.\n\n%s\n\n%s") % \
                        (func_name, func_names[func_name][0], func_xcode)

        return out_types

    def used_funcs(self, mode):
        """Return self.functions[mode] reduced to the functions referenced from the entry types.

        The entry functions are always kept. Walking self.functions[mode] in reverse, a function
        is kept if its name, followed by a non-word character, occurs in the code kept so far.
        The result lists the kept functions in their original order, followed by the entry
        functions in reverse order.
        """
        mod_entry_types = [
            XcoderTuple(
                func_type.xcode(),
                func_type.xcode_func_name(),
                func_type.type_name()) for func_type in self.entry_types[mode]]
        out_types = list(mod_entry_types)
        full_code = "".join([func_type[0] for func_type in mod_entry_types])
        for func_type in reversed(self.functions[mode]):
            func_name = func_type[1]
            if func_type not in mod_entry_types and search(r"%s\W" % func_name, full_code):
                full_code += func_type[0]
                out_types.append(func_type)
        return list(reversed(out_types))

    def render_forward_declaration(self, xcoder, mode):
        """Render the forward declaration (prototype) of a static encoder/decoder function."""
        return f"""
static bool {xcoder.func_name}(zcbor_state_t *state, {"" if mode == "decode" else "const "}{
            xcoder.type_name
            if struct_ptr_name(mode) in xcoder.body else "void"} *{struct_ptr_name(mode)});
            """.strip()

    def render_function(self, xcoder, mode):
        """Render a single static encoder/decoder function with signature and body."""
        body = xcoder.body
        return f"""
static bool {xcoder.func_name}(
		zcbor_state_t *state, {"" if mode == "decode" else "const "}{
            xcoder.type_name
            if struct_ptr_name(mode) in body else "void"} *{struct_ptr_name(mode)})
{{
	zcbor_print("%s\\r\\n", __func__);
	{"struct zcbor_string tmp_str;" if "tmp_str" in body else ""}
	{"bool int_res;" if "int_res" in body else ""}

	bool tmp_result = ({ body });

	if (!tmp_result)
		zcbor_trace();

	return tmp_result;
}}""".replace("	\n", "")  # call replace() to remove empty lines.

    def render_entry_function(self, xcoder, mode):
        """Render a single entry function (API function) with signature and body.

        The function creates num_backups() + 2 zcbor states, calls the element's xcoder, reports
        the consumed/produced length through payload_len_out, and returns a zcbor error code.
        """
        func_name, func_arg = (xcoder.xcode_func_name(), struct_ptr_name(mode))
        return f"""
{xcoder.public_xcode_func_sig()}
{{
	zcbor_state_t states[{xcoder.num_backups() + 2}];

	zcbor_new_state(states, sizeof(states) / sizeof(zcbor_state_t), payload, payload_len, {
        xcoder.list_counts()[1]});

	bool ret = {func_name}(states, {func_arg});

	if (ret && (payload_len_out != NULL)) {{
		*payload_len_out = MIN(payload_len,
				(size_t)states[0].payload - (size_t)payload);
	}}

	if (!ret) {{
		int err = zcbor_pop_error(states);

		zcbor_print("Return error: %d\\r\\n", err);
		return (err == ZCBOR_SUCCESS) ? ZCBOR_ERR_UNKNOWN : err;
	}}
	return ZCBOR_SUCCESS;
}}"""

    def render_file_header(self, line_prefix):
        """Return self.file_header as comment lines, each starting with ``line_prefix``.

        The result starts with a newline. Trailing spaces on empty lines are removed.
        """
        lp = line_prefix
        return (f"\n{lp} " + self.file_header.replace("\n", f"\n{lp} ")).replace(" \n", "\n")

    def render_c_file(self, header_file_name, mode):
        """Render the entire generated C file contents for ``mode``."""
        return f"""/*{self.render_file_header(" *")}
 */

#include <stdint.h>
#include <stdbool.h>
#include <stddef.h>
#include <string.h>
#include "zcbor_{mode}.h"
#include "{header_file_name}"

#if DEFAULT_MAX_QTY != {self.default_max_qty}
#error "The type file was generated with a different default_max_qty than this file"
#endif

{linesep.join([self.render_forward_declaration(xcoder, mode) for xcoder in self.functions[mode]])}

{linesep.join([self.render_function(xcoder, mode) for xcoder in self.functions[mode]])}

{linesep.join([self.render_entry_function(xcoder, mode) for xcoder in self.entry_types[mode]])}
"""

    def render_h_file(self, type_def_file, header_guard, mode):
        """Render the entire generated header file contents (entry function prototypes)."""
        return \
            f"""/*{self.render_file_header(" *")}
 */

#ifndef {header_guard}
#define {header_guard}

#include <stdint.h>
#include <stdbool.h>
#include <stddef.h>
#include <string.h>
#include "{type_def_file}"

#ifdef __cplusplus
extern "C" {{
#endif

#if DEFAULT_MAX_QTY != {self.default_max_qty}
#error "The type file was generated with a different default_max_qty than this file"
#endif

{(linesep+linesep).join(
    [f"{xcoder.public_xcode_func_sig()};" for xcoder in self.entry_types[mode]])}


#ifdef __cplusplus
}}
#endif

#endif /* {header_guard} */
"""

    def render_type_file(self, header_guard, mode):
        """Render the type header: DEFAULT_MAX_QTY and the typedefs collected for ``mode``."""
        body = (
            linesep + linesep).join(
                [f"{typedef[1]} {{{linesep}{linesep.join(typedef[0][1:])};"
                    for typedef in self.type_defs[mode]])
        return \
            f"""/*{self.render_file_header(" *")}
 */

#ifndef {header_guard}
#define {header_guard}

#include <stdint.h>
#include <stdbool.h>
#include <stddef.h>
{'#include <zcbor_common.h>' if "struct zcbor_string" in body else ""}

#ifdef __cplusplus
extern "C" {{
#endif

/** Which value for --default-max-qty this file was created with.
 *
 *  The define is used in the other generated file to do a build-time
 *  compatibility check.
 *
 *  See `zcbor --help` for more information about --default-max-qty
 */
#define DEFAULT_MAX_QTY {self.default_max_qty}

{body}

#ifdef __cplusplus
}}
#endif

#endif /* {header_guard} */
"""

    def render_cmake_file(self, target_name, h_files, c_files, type_file,
                          output_c_dir, output_h_dir, cmake_dir):
        """Render a CMake file that defines library ``target_name``.

        The library's sources are the zcbor sources plus the generated C files. Its include
        directories are the header directories. All paths are made relative to ``cmake_dir``.
        """
        include_dirs = sorted(set(((Path(output_h_dir)),
                                  (Path(type_file.name).parent),
                                  *((Path(h.name).parent) for h in h_files.values()))))

        def relativify(p):
            return "${CMAKE_CURRENT_LIST_DIR}/" + path.relpath(Path(p), cmake_dir)
        return \
            f"""\
#{self.render_file_header("#")}
#

add_library({target_name})
target_sources({target_name} PRIVATE
    {relativify(Path(output_c_dir, "zcbor_decode.c"))}
    {relativify(Path(output_c_dir, "zcbor_encode.c"))}
    {relativify(Path(output_c_dir, "zcbor_common.c"))}
    {(linesep + "    ").join(((relativify(c.name)) for c in c_files.values()))}
    )
target_include_directories({target_name} PUBLIC
    {(linesep + "    ").join(((relativify(f) for f in include_dirs)))}
    )
"""

    def render(self, modes, h_files, c_files, type_file, include_prefix, cmake_file=None,
               output_c_dir=None, output_h_dir=None):
        """Write the C and header file for each mode, the type header, and the CMake file.

        The CMake file is only written if ``cmake_file`` is given. h_files and c_files are dicts
        keyed by mode. All file arguments are file-like objects with ``.name`` and ``.write()``.
        """
        for mode in modes:
            h_name = Path(include_prefix, Path(h_files[mode].name).name)

            # Make sure the C file's directory exists. The path is resolved from the file name
            # itself, because prefixing "./" would turn an absolute path into a relative one.
            Path(c_files[mode].name).absolute().parent.mkdir(parents=True, exist_ok=True)

            type_def_name = Path(include_prefix, Path(type_file.name).name)

            print("Writing to " + c_files[mode].name)
            c_files[mode].write(self.render_c_file(h_name, mode))

            print("Writing to " + h_files[mode].name)
            h_files[mode].write(self.render_h_file(
                type_def_name,
                self.header_guard(h_files[mode].name), mode))

        print("Writing to " + type_file.name)
        # NOTE(review): this uses the last mode from the loop above. The type file is shared
        # between modes, so the typedefs are presumably identical for each mode; confirm.
        type_file.write(self.render_type_file(self.header_guard(type_file.name), mode))

        if cmake_file:
            print("Writing to " + cmake_file.name)
            cmake_file.write(self.render_cmake_file(
                Path(cmake_file.name).stem, h_files, c_files, type_file,
                output_c_dir, output_h_dir, Path(cmake_file.name).absolute().parent))
2798
2799
def int_or_str(arg):
    """Argparse ``type=`` converter that accepts an integer or a symbolic name.

    Returns ``int(arg)`` when ``arg`` parses as an integer. Otherwise returns ``arg`` unchanged
    if it consists only of word characters (letters, digits, '_'). Any other value raises
    ArgumentTypeError.
    """
    try:
        return int(arg)
    except ValueError:
        pass

    if match(r"\A\w+\Z", arg) is None:
        raise ArgumentTypeError(
            "Argument must be an integer or a string with only letters, numbers, or '_'.")
    return arg
2809
2810
2811def parse_args():
2812
2813    parent_parser = ArgumentParser(add_help=False)
2814
2815    parent_parser.add_argument(
2816        "-c", "--cddl", required=True, type=FileType('r'), action="append",
2817        help="""Path to one or more input CDDL file(s). Passing multiple files is equivalent to
2818concatenating them.""")
2819    parent_parser.add_argument(
2820        "--no-prelude", required=False, action="store_true", default=False,
2821        help=f"""Exclude the standard CDDL prelude from the build. The prelude can be viewed at
2822{PRELUDE_path.relative_to(P_REPO_ROOT)} in the repo, or together with the script.""")
2823    parent_parser.add_argument(
2824        "-v", "--verbose", required=False, action="store_true", default=False,
2825        help="Print more information while parsing CDDL and generating code.")
2826
2827    parser = ArgumentParser(
2828        description='''Parse a CDDL file and validate/convert between YAML, JSON, and CBOR.
2829Can also generate C code for validation/encoding/decoding of CBOR.''')
2830
2831    parser.add_argument(
2832        "--version", action="version", version=f"zcbor {__version__}")
2833
2834    subparsers = parser.add_subparsers()
2835    code_parser = subparsers.add_parser(
2836        "code", description='''Parse a CDDL file and produce C code that validates and xcodes CBOR.
2837The output from this script is a C file and a header file. The header file
2838contains typedefs for all the types specified in the cddl input file, as well
2839as declarations to xcode functions for the types designated as entry types when
2840running the script. The c file contains all the code for decoding and validating
2841the types in the CDDL input file. All types are validated as they are xcoded.
2842
2843Where a `bstr .cbor <Type>` is specified in the CDDL, AND the Type is an entry
2844type, the xcoder will not xcode the string, only provide a pointer into the
2845payload buffer. This is useful to reduce the size of typedefs, or to break up
2846decoding. Using this mechanism is necessary when the CDDL contains self-
2847referencing types, since the C type cannot be self referencing.
2848
2849This script requires 'regex' for lookaround functionality not present in 're'.''',
2850        formatter_class=RawDescriptionHelpFormatter,
2851        parents=[parent_parser])
2852
2853    code_parser.add_argument(
2854        "--default-max-qty", "--dq", required=False, type=int_or_str, default=3,
2855        help="""Default maximum number of repetitions when no maximum
2856is specified. This is needed to construct complete C types.
2857
2858The default_max_qty can usually be set to a text symbol if desired,
2859to allow it to be configurable when building the code. This is not always
2860possible, as sometimes the value is needed for internal computations.
2861If so, the script will raise an exception.""")
2862    code_parser.add_argument(
2863        "--output-c", "--oc", required=False, type=str,
2864        help="""Path to output C file. If both --decode and --encode are specified, _decode and
2865_encode will be appended to the filename when creating the two files. If not
2866specified, the path and name will be based on the --output-cmake file. A 'src'
2867directory will be created next to the cmake file, and the C file will be
2868placed there with the same name (except the file extension) as the cmake file.""")
2869    code_parser.add_argument(
2870        "--output-h", "--oh", required=False, type=str,
2871        help="""Path to output header file. If both --decode and --encode are specified, _decode and
2872_encode will be appended to the filename when creating the two files. If not
2873specified, the path and name will be based on the --output-cmake file. An 'include'
2874directory will be created next to the cmake file, and the C file will be
2875placed there with the same name (except the file extension) as the cmake file.""")
2876    code_parser.add_argument(
2877        "--output-h-types", "--oht", required=False, type=str,
2878        help="""Path to output header file with typedefs (shared between decode and encode).
2879If not specified, the path and name will be taken from the output header file
2880(--output-h), with '_types' added to the file name.""")
2881    code_parser.add_argument(
2882        "--copy-sources", required=False, action="store_true", default=False,
2883        help="""Copy the non-generated source files (zcbor_*.c/h) into the same directories as the
2884generated files.""")
2885    code_parser.add_argument(
2886        "--output-cmake", required=False, type=str,
2887        help="""Path to output CMake file. The filename of the CMake file without '.cmake' is used
2888as the name of the CMake target in the file.
2889The CMake file defines a CMake target with the zcbor source files and the
2890generated file as sources, and the zcbor header files' and generated header
2891files' folders as include_directories.
2892Add it to your project via include() in your CMakeLists.txt file, and link the
2893target to your program.
2894This option works with or without the --copy-sources option.""")
2895    code_parser.add_argument(
2896        "-t", "--entry-types", required=True, type=str, nargs="+",
2897        help="Names of the types which should have their xcode functions exposed.")
2898    code_parser.add_argument(
2899        "-d", "--decode", required=False, action="store_true", default=False,
2900        help="Generate decoding code. Either --decode or --encode or both must be specified.")
2901    code_parser.add_argument(
2902        "-e", "--encode", required=False, action="store_true", default=False,
2903        help="Generate encoding code. Either --decode or --encode or both must be specified.")
2904    code_parser.add_argument(
2905        "--time-header", required=False, action="store_true", default=False,
2906        help="Put the current time in a comment in the generated files.")
2907    code_parser.add_argument(
2908        "--git-sha-header", required=False, action="store_true", default=False,
2909        help="Put the current git sha of zcbor in a comment in the generated files.")
2910    code_parser.add_argument(
2911        "-b", "--default-bit-size", required=False, type=int, default=32, choices=[32, 64],
2912        help="""Default bit size of integers in code. When integers have no explicit bounds,
2913assume they have this bit width. Should follow the bit width of the architecture
2914the code will be running on.""")
2915    code_parser.add_argument(
2916        "--include-prefix", default="",
2917        help="""When #include'ing generated files, add this path prefix to the filename.""")
2918    code_parser.add_argument(
2919        "-s", "--short-names", required=False, action="store_true", default=False,
2920        help="""Attempt to make most generated struct member names shorter. This might make some
2921names identical which will cause a compile error. If so, tweak the CDDL labels
2922or layout, or disable this option. This might also make enum names different
2923from the corresponding union members.""")
2924    code_parser.add_argument(
2925        "--file-header", required=False, type=str, default="",
2926        help="Header to be included in the comment at the top of generated C files, e.g. copyright."
2927    )
2928    code_parser.set_defaults(process=process_code)
2929
2930    validate_parent_parser = ArgumentParser(add_help=False)
2931    validate_parent_parser.add_argument(
2932        "-i", "--input", required=True, type=str,
2933        help='''Input data file. The option --input-as specifies how to interpret the contents.
2934Use "-" to indicate stdin.''')
2935    validate_parent_parser.add_argument(
2936        "--input-as", required=False, choices=["yaml", "json", "cbor", "cborhex"],
2937        help='''Which format to interpret the input file as.
2938If omitted, the format is inferred from the file name.
2939.yaml, .yml => YAML, .json => JSON, .cborhex => CBOR as hex string, everything else => CBOR''')
2940    validate_parent_parser.add_argument(
2941        "-t", "--entry-type", required=True, type=str,
2942        help='''Name of the type (from the CDDL) to interpret the data as.''')
2943    validate_parent_parser.add_argument(
2944        "--default-max-qty", "--dq", required=False, type=int, default=0xFFFFFFFF,
2945        help="""Default maximum number of repetitions when no maximum is specified.
2946It is only relevant when handling data that will be decoded by generated code.
2947If omitted, a large number will be used.""")
2948    validate_parent_parser.add_argument(
2949        "--yaml-compatibility", required=False, action="store_true", default=False,
2950        help='''Whether to convert CBOR-only values to YAML-compatible ones
2951(when converting from CBOR), or vice versa (when converting to CBOR).
2952
2953When this is enabled, all CBOR data is guaranteed to convert into YAML/JSON.
2954JSON and YAML do not support all data types that CBOR/CDDL supports.
2955bytestrings (BSTR), tags, undefined, and maps with non-text keys need
2956special handling. See the zcbor README for more information.''')
2957
2958    validate_parser = subparsers.add_parser(
2959        "validate", description='''Read CBOR, YAML, or JSON data from file or stdin and validate
2960it against a CDDL schema file.
2961        ''',
2962        parents=[parent_parser, validate_parent_parser])
2963
2964    validate_parser.set_defaults(process=process_validate)
2965
2966    convert_parser = subparsers.add_parser(
2967        "convert", description='''Parse a CDDL file and validate/convert between CBOR and YAML/JSON.
2968The script decodes the CBOR/YAML/JSON data from a file or stdin
2969and verifies that it conforms to the CDDL description.
2970The script fails if the data does not conform.
2971'zcbor validate' can be used if only validate is needed.''',
2972        parents=[parent_parser, validate_parent_parser])
2973
2974    convert_parser.add_argument(
2975        "-o", "--output", required=True, type=str,
2976        help='''Output data file. The option --output-as specifies how to interpret the contents.
2977 Use "-" to indicate stdout.''')
2978    convert_parser.add_argument(
2979        "--output-as", required=False, choices=["yaml", "json", "cbor", "cborhex", "c_code"],
2980        help='''Which format to interpret the output file as.
2981If omitted, the format is inferred from the file name.
2982.yaml, .yml => YAML, .json => JSON, .c, .h => C code,
2983.cborhex => CBOR as hex string, everything else => CBOR''')
2984    convert_parser.add_argument(
2985        "--c-code-var-name", required=False, type=str,
2986        help='''Only relevant together with '--output-as c_code' or .c files.''')
2987    convert_parser.add_argument(
2988        "--c-code-columns", required=False, type=int, default=0,
2989        help='''Only relevant together with '--output-as c_code' or .c files.
2990The number of bytes per line in the variable instantiation. If omitted, the
2991entire declaration is a single line.''')
2992    convert_parser.set_defaults(process=process_convert)
2993
2994    args = parser.parse_args()
2995
2996    if not args.no_prelude:
2997        args.cddl.append(open(PRELUDE_path, 'r'))
2998
2999    if hasattr(args, "decode") and not args.decode and not args.encode:
3000        parser.error("Please specify at least one of --decode or --encode.")
3001
3002    if hasattr(args, "output_c"):
3003        if not args.output_c or not args.output_h:
3004            if not args.output_cmake:
3005                parser.error(
3006                    "Please specify both --output-c and --output-h "
3007                    "unless --output-cmake is specified.")
3008
3009    return args
3010
3011
def process_code(args):
    """Handle the "code" subcommand: generate C code (and optionally CMake) from CDDL.

    One CodeGenerator is built per requested mode ("decode" and/or "encode").
    Output files are opened (creating parent directories as needed), with any
    paths that were not given derived from --output-cmake. The non-generated
    zcbor sources are optionally copied next to the generated files, and
    everything is handed to CodeRenderer.render().

    Args:
        args: Parsed command-line arguments for the "code" subcommand.
    """
    modes = list()
    if args.decode:
        modes.append("decode")
    if args.encode:
        modes.append("encode")

    print("Parsing files: " + ", ".join((c.name for c in args.cddl)))

    cddl_contents = linesep.join((c.read() for c in args.cddl))

    cddl_res = dict()
    for mode in modes:
        cddl_res[mode] = CodeGenerator.from_cddl(
            mode, cddl_contents, args.default_max_qty, mode, args.entry_types,
            args.default_bit_size, short_names=args.short_names)

    # Parsing is done, pretty print the result.
    verbose_print(args.verbose, "Parsed CDDL types:")
    for mode in modes:
        verbose_pprint(args.verbose, cddl_res[mode].my_types)

    git_sha = ''
    if args.git_sha_header:
        if "zcbor.py" in sys.argv[0]:
            # Invoked as zcbor.py: ask git for the sha of the checkout at P_REPO_ROOT.
            git_args = ['git', 'rev-parse', '--verify', '--short', 'HEAD']
            git_sha = Popen(
                git_args, cwd=P_REPO_ROOT, stdout=PIPE).communicate()[0].decode('utf-8').strip()
        else:
            # Otherwise fall back to the version string read from the VERSION file.
            git_sha = __version__

    def create_and_open(path, mode='w'):
        """Open path (default: text write mode), creating its parent directories first."""
        Path(path).absolute().parent.mkdir(parents=True, exist_ok=True)
        return Path(path).open(mode)

    if args.output_cmake:
        cmake_path = Path(args.output_cmake)
        cmake_dir = cmake_path.parent
        output_cmake = create_and_open(args.output_cmake)
        # The CMake file name without its trailing '.cmake' names the target and is
        # the base name of derived output files. Strip only the trailing suffix:
        # str.replace() would also remove '.cmake' from the middle of the name.
        filenames = cmake_path.name
        if filenames.endswith(".cmake"):
            filenames = filenames[:-len(".cmake")]
    else:
        output_cmake = None

    def add_mode_to_fname(filename, mode):
        """Return filename with '_<mode>' inserted before its extension."""
        name = Path(filename).stem + "_" + mode + Path(filename).suffix
        return Path(filename).with_name(name)

    output_c = dict()
    output_h = dict()
    # An explicit --output-c/--output-h path is used verbatim only when a single
    # mode is generated; with both modes, _decode/_encode is appended to the name.
    out_c = args.output_c if (len(modes) == 1 and args.output_c) else None
    out_h = args.output_h if (len(modes) == 1 and args.output_h) else None
    for mode in modes:
        output_c[mode] = create_and_open(
            out_c or add_mode_to_fname(
                args.output_c or Path(cmake_dir, 'src', f'{filenames}.c'), mode))
        output_h[mode] = create_and_open(
            out_h or add_mode_to_fname(
                args.output_h or Path(cmake_dir, 'include', f'{filenames}.h'), mode))

    out_c_parent = Path(output_c[modes[0]].name).parent
    out_h_parent = Path(output_h[modes[0]].name).parent

    # The shared typedef header: explicit path, else derived from --output-h,
    # else derived from the CMake file name.
    output_h_types = create_and_open(
        args.output_h_types
        or (args.output_h and Path(args.output_h).with_name(Path(args.output_h).stem + "_types.h"))
        or Path(cmake_dir, 'include', filenames + '_types.h'))

    renderer = CodeRenderer(entry_types={mode: [cddl_res[mode].my_types[entry]
                                         for entry in args.entry_types] for mode in modes},
                            modes=modes, print_time=args.time_header,
                            default_max_qty=args.default_max_qty, git_sha=git_sha,
                            file_header=args.file_header
                            )

    c_code_dir = Path(c_code_root, "src")
    h_code_dir = Path(c_code_root, "include")

    if args.copy_sources:
        new_c_code_dir = out_c_parent
        new_h_code_dir = out_h_parent
        # Non-generated zcbor sources, copied next to the generated files.
        for c_file in ("zcbor_decode.c", "zcbor_encode.c", "zcbor_common.c"):
            copyfile(Path(c_code_dir, c_file), Path(new_c_code_dir, c_file))
        for h_file in ("zcbor_decode.h", "zcbor_encode.h", "zcbor_common.h",
                       "zcbor_tags.h", "zcbor_debug.h"):
            copyfile(Path(h_code_dir, h_file), Path(new_h_code_dir, h_file))
        c_code_dir = new_c_code_dir
        h_code_dir = new_h_code_dir

    renderer.render(modes, output_h, output_c, output_h_types, args.include_prefix,
                    output_cmake, c_code_dir, h_code_dir)
3104
3105
def parse_cddl(args):
    """Parse all CDDL inputs and return the type selected by --entry-type.

    The contents of every file in args.cddl are joined with the platform line
    separator and parsed by DataTranslator with args.default_max_qty.
    """
    combined_source = linesep.join(cddl_file.read() for cddl_file in args.cddl)
    translator = DataTranslator.from_cddl(combined_source, args.default_max_qty)
    parsed_types = translator.my_types
    return parsed_types[args.entry_type]
3110
3111
def read_data(args, cddl):
    """Read the input data (file or stdin) and return it as CBOR bytes.

    The format is taken from --input-as, or otherwise inferred from the input
    file's extension: yaml/yml => YAML, json => JSON, cborhex => hex-encoded
    CBOR, anything else => raw CBOR. An input of "-" reads from stdin.

    Args:
        args: Parsed arguments; uses input, input_as and yaml_compatibility.
        cddl: The CDDL entry type returned by parse_cddl().

    Returns:
        The data as CBOR. YAML/JSON input is converted with cddl.from_yaml() /
        cddl.from_json(). cborhex and raw CBOR input is checked with
        cddl.validate_str() before being returned.
    """
    _, in_file_ext = path.splitext(args.input)
    in_file_format = args.input_as or in_file_ext.strip(".")
    # Raw CBOR is read as bytes; all other formats are text.
    binary = in_file_format not in ["yaml", "yml", "json", "cborhex"]

    if args.input == "-":
        contents = (sys.stdin.buffer if binary else sys.stdin).read()
    else:
        # Close the input file once read (the previous code leaked the handle).
        with open(args.input, "rb" if binary else "r") as in_file:
            contents = in_file.read()

    if in_file_format in ["yaml", "yml"]:
        cbor_str = cddl.from_yaml(contents, yaml_compat=args.yaml_compatibility)
    elif in_file_format == "json":
        cbor_str = cddl.from_json(contents, yaml_compat=args.yaml_compatibility)
    elif in_file_format == "cborhex":
        cbor_str = bytes.fromhex(contents.replace("\n", ""))
        cddl.validate_str(cbor_str)
    else:
        cbor_str = contents
        cddl.validate_str(cbor_str)

    return cbor_str
3131
3132
def write_data(args, cddl, cbor_str):
    """Write CBOR data to --output (file or stdout) in the requested format.

    The format is taken from --output-as, or otherwise inferred from the output
    file's extension: yaml/yml => YAML, json => JSON, c/h/c_code => C code,
    cborhex => hex string, anything else => raw CBOR bytes. An output of "-"
    writes to stdout.

    Args:
        args: Parsed arguments; uses output, output_as, yaml_compatibility and,
            for C code, c_code_var_name and c_code_columns.
        cddl: The CDDL entry type returned by parse_cddl().
        cbor_str: The CBOR data (bytes).

    Raises:
        ValueError: C code output was requested without --c-code-var-name.
            The check runs before the output file is opened, so no file is
            created or truncated in that case.
    """
    _, out_file_ext = path.splitext(args.output)
    out_file_format = args.output_as or out_file_ext.strip(".")
    binary = False

    if out_file_format in ["yaml", "yml"]:
        data = cddl.str_to_yaml(cbor_str, yaml_compat=args.yaml_compatibility)
    elif out_file_format == "json":
        data = cddl.str_to_json(cbor_str, yaml_compat=args.yaml_compatibility)
    elif out_file_format in ["c", "h", "c_code"]:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if args.c_code_var_name is None:
            raise ValueError("Must specify --c-code-var-name when outputting c code.")
        data = cddl.str_to_c_code(cbor_str, args.c_code_var_name, args.c_code_columns)
    elif out_file_format == "cborhex":
        # Append a space after every 64-character chunk of hex. bytes.fromhex()
        # ignores ASCII whitespace when this is read back.
        # NOTE(review): the original comment said "newlines", but the code has
        # always emitted spaces; output is kept unchanged for compatibility.
        data = sub(r"(.{1,64})", r"\1 ", cbor_str.hex())
    else:
        data = cbor_str
        binary = True

    if args.output == "-":
        (sys.stdout.buffer if binary else sys.stdout).write(data)
    else:
        # Close (and thereby flush) the output file; the previous code never closed it.
        with open(args.output, "wb" if binary else "w") as out_file:
            out_file.write(data)
3153
3154
def process_validate(args):
    """Handle the "validate" subcommand: check the input data against the CDDL entry type.

    The decoded data is discarded; any error raised while reading or validating
    propagates to the caller.
    """
    read_data(args, parse_cddl(args))
3158
3159
def process_convert(args):
    """Handle the "convert" subcommand: read and validate the input, then write it in the output format."""
    entry_type = parse_cddl(args)
    write_data(args, entry_type, read_data(args, entry_type))
3164
3165
def main():
    """Script entry point: parse the command line and dispatch to the chosen subcommand."""
    parsed_args = parse_args()
    # Each subparser registers its handler via set_defaults(process=...).
    handler = parsed_args.process
    handler(parsed_args)


if __name__ == "__main__":
    main()
3173