#!/usr/bin/env python3
# kate: replace-tabs on; indent-width 4;

from __future__ import unicode_literals

'''Generate header file for nanopb from a ProtoBuf FileDescriptorSet.'''
nanopb_version = "nanopb-0.4.7"

import sys
import re
import codecs
import copy
import itertools
import tempfile
import shutil
import shlex
import os
from functools import reduce

# Python-protobuf breaks easily with protoc version differences if
# using the cpp or upb implementation. Force it to use pure Python
# implementation. Performance is not very important in the generator.
if not os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"):
    os.putenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python")
    os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

try:
    # Make sure grpc_tools gets included in binary package if it is available
    import grpc_tools.protoc
except ImportError:
    pass

try:
    import google.protobuf.text_format as text_format
    import google.protobuf.descriptor_pb2 as descriptor
    import google.protobuf.compiler.plugin_pb2 as plugin_pb2
    import google.protobuf.reflection as reflection
    import google.protobuf.descriptor
except ImportError:
    sys.stderr.write('''
         **********************************************************************
         *** Could not import the Google protobuf Python libraries          ***
         ***                                                                ***
         *** Easiest solution is often to install the dependencies via pip: ***
         ***    pip install protobuf grpcio-tools                           ***
         **********************************************************************
    ''' + '\n')
    raise

# Depending on how this script is run, we may or may not have a PEP 366 package
# name available for relative imports.
if not __package__:
    import proto
    from proto._utils import invoke_protoc
    from proto import TemporaryDirectory
else:
    from . import proto
    from .proto._utils import invoke_protoc
    from .proto import TemporaryDirectory

if getattr(sys, 'frozen', False):
    # Binary package, just import the file
    from proto import nanopb_pb2
else:
    # Try to rebuild nanopb_pb2.py if necessary
    nanopb_pb2 = proto.load_nanopb_pb2()

try:
    # Add some dummy imports to keep packaging tools happy.
    import google # bbfreeze seems to need these
    import pkg_resources # pyinstaller / protobuf 2.5 seem to need these
    from proto import nanopb_pb2 # pyinstaller seems to need this
    import pkg_resources.py2_warn
except Exception:
    # Don't care, we will error out later if it is actually important.
    pass

# ---------------------------------------------------------------------------
#                     Generation of single fields
# ---------------------------------------------------------------------------

import time
import os.path

# Values are tuples (C type, pb type, max encoded size, data item size)
FieldD = descriptor.FieldDescriptorProto
datatypes = {
    FieldD.TYPE_BOOL:       ('bool',     'BOOL',        1,  4),
    FieldD.TYPE_DOUBLE:     ('double',   'DOUBLE',      8,  8),
    FieldD.TYPE_FIXED32:    ('uint32_t', 'FIXED32',     4,  4),
    FieldD.TYPE_FIXED64:    ('uint64_t', 'FIXED64',     8,  8),
    FieldD.TYPE_FLOAT:      ('float',    'FLOAT',       4,  4),
    FieldD.TYPE_INT32:      ('int32_t',  'INT32',      10,  4),
    FieldD.TYPE_INT64:      ('int64_t',  'INT64',      10,  8),
    FieldD.TYPE_SFIXED32:   ('int32_t',  'SFIXED32',    4,  4),
    FieldD.TYPE_SFIXED64:   ('int64_t',  'SFIXED64',    8,  8),
    FieldD.TYPE_SINT32:     ('int32_t',  'SINT32',      5,  4),
    FieldD.TYPE_SINT64:     ('int64_t',  'SINT64',     10,  8),
    FieldD.TYPE_UINT32:     ('uint32_t', 'UINT32',      5,  4),
    FieldD.TYPE_UINT64:     ('uint64_t', 'UINT64',     10,  8),

    # Integer size override options
    (FieldD.TYPE_INT32,   nanopb_pb2.IS_8):   ('int8_t',   'INT32', 10,  1),
    (FieldD.TYPE_INT32,  nanopb_pb2.IS_16):   ('int16_t',  'INT32', 10,  2),
    (FieldD.TYPE_INT32,  nanopb_pb2.IS_32):   ('int32_t',  'INT32', 10,  4),
    (FieldD.TYPE_INT32,  nanopb_pb2.IS_64):   ('int64_t',  'INT32', 10,  8),
    (FieldD.TYPE_SINT32,  nanopb_pb2.IS_8):   ('int8_t',  'SINT32',  2,  1),
    (FieldD.TYPE_SINT32, nanopb_pb2.IS_16):   ('int16_t', 'SINT32',  3,  2),
    (FieldD.TYPE_SINT32, nanopb_pb2.IS_32):   ('int32_t', 'SINT32',  5,  4),
    (FieldD.TYPE_SINT32, nanopb_pb2.IS_64):   ('int64_t', 'SINT32', 10,  8),
    (FieldD.TYPE_UINT32,  nanopb_pb2.IS_8):   ('uint8_t', 'UINT32',  2,  1),
    (FieldD.TYPE_UINT32, nanopb_pb2.IS_16):   ('uint16_t','UINT32',  3,  2),
    (FieldD.TYPE_UINT32, nanopb_pb2.IS_32):   ('uint32_t','UINT32',  5,  4),
    (FieldD.TYPE_UINT32, nanopb_pb2.IS_64):   ('uint64_t','UINT32', 10,  8),
    (FieldD.TYPE_INT64,   nanopb_pb2.IS_8):   ('int8_t',   'INT64', 10,  1),
    (FieldD.TYPE_INT64,  nanopb_pb2.IS_16):   ('int16_t',  'INT64', 10,  2),
    (FieldD.TYPE_INT64,  nanopb_pb2.IS_32):   ('int32_t',  'INT64', 10,  4),
    (FieldD.TYPE_INT64,  nanopb_pb2.IS_64):   ('int64_t',  'INT64', 10,  8),
    (FieldD.TYPE_SINT64,  nanopb_pb2.IS_8):   ('int8_t',  'SINT64',  2,  1),
    (FieldD.TYPE_SINT64, nanopb_pb2.IS_16):   ('int16_t', 'SINT64',  3,  2),
    (FieldD.TYPE_SINT64, nanopb_pb2.IS_32):   ('int32_t', 'SINT64',  5,  4),
    (FieldD.TYPE_SINT64, nanopb_pb2.IS_64):   ('int64_t', 'SINT64', 10,  8),
    (FieldD.TYPE_UINT64,  nanopb_pb2.IS_8):   ('uint8_t', 'UINT64',  2,  1),
    (FieldD.TYPE_UINT64, nanopb_pb2.IS_16):   ('uint16_t','UINT64',  3,  2),
    (FieldD.TYPE_UINT64, nanopb_pb2.IS_32):   ('uint32_t','UINT64',  5,  4),
    (FieldD.TYPE_UINT64, nanopb_pb2.IS_64):   ('uint64_t','UINT64', 10,  8),
}

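# A couple of illustrative lookups (not exhaustive): a plain sint32 field maps to
# datatypes[FieldD.TYPE_SINT32] == ('int32_t', 'SINT32', 5, 4), while the same
# field with the nanopb option int_size = IS_8 maps to
# datatypes[(FieldD.TYPE_SINT32, nanopb_pb2.IS_8)] == ('int8_t', 'SINT32', 2, 1).
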
class NamingStyle:
    def enum_name(self, name):
        return "_%s" % (name)

    def struct_name(self, name):
        return "_%s" % (name)

    def type_name(self, name):
        return "%s" % (name)

    def define_name(self, name):
        return "%s" % (name)

    def var_name(self, name):
        return "%s" % (name)

    def enum_entry(self, name):
        return "%s" % (name)

    def func_name(self, name):
        return "%s" % (name)

    def bytes_type(self, struct_name, name):
        return "%s_%s_t" % (struct_name, name)

class NamingStyleC(NamingStyle):
    def enum_name(self, name):
        return self.underscore(name)

    def struct_name(self, name):
        return self.underscore(name)

    def type_name(self, name):
        return "%s_t" % self.underscore(name)

    def define_name(self, name):
        return self.underscore(name).upper()

    def var_name(self, name):
        return self.underscore(name)

    def enum_entry(self, name):
        return self.underscore(name).upper()

    def func_name(self, name):
        return self.underscore(name)

    def bytes_type(self, struct_name, name):
        return "%s_%s_t" % (self.underscore(struct_name), self.underscore(name))

    def underscore(self, word):
        word = str(word)
        word = re.sub(r"([A-Z]+)([A-Z][a-z])", r'\1_\2', word)
        word = re.sub(r"([a-z\d])([A-Z])", r'\1_\2', word)
        word = word.replace("-", "_")
        return word.lower()

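# Illustrative example (assumed typical CamelCase input): with the default
# NamingStyle the names pass through unchanged, whereas NamingStyleC converts
# them to snake_case, e.g. NamingStyleC().underscore("FooBar") == "foo_bar"
# and NamingStyleC().type_name("FooBar") == "foo_bar_t".
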
class Globals:
    '''Ugly global variables, should find a good way to pass these.'''
    verbose_options = False
    separate_options = []
    matched_namemasks = set()
    protoc_insertion_points = False
    naming_style = NamingStyle()

# String types and file encoding for Python2 UTF-8 support
if sys.version_info.major == 2:
    import codecs
    open = codecs.open
    strtypes = (unicode, str)

    def str(x):
        try:
            return strtypes[1](x)
        except UnicodeEncodeError:
            return strtypes[0](x)
else:
    strtypes = (str, )


class Names:
    '''Keeps a set of nested names and formats them to a C identifier.'''
    def __init__(self, parts = ()):
        if isinstance(parts, Names):
            parts = parts.parts
        elif isinstance(parts, strtypes):
            parts = (parts,)
        self.parts = tuple(parts)

        if self.parts == ('',):
            self.parts = ()

    def __str__(self):
        return '_'.join(self.parts)

    def __repr__(self):
        return 'Names(%s)' % ','.join("'%s'" % x for x in self.parts)

    def __add__(self, other):
        if isinstance(other, strtypes):
            return Names(self.parts + (other,))
        elif isinstance(other, Names):
            return Names(self.parts + other.parts)
        elif isinstance(other, tuple):
            return Names(self.parts + other)
        else:
            raise ValueError("Name parts should be of type str")

    def __eq__(self, other):
        return isinstance(other, Names) and self.parts == other.parts

    def __lt__(self, other):
        if not isinstance(other, Names):
            return NotImplemented
        return str(self) < str(other)

def names_from_type_name(type_name):
    '''Parse Names() from FieldDescriptorProto type_name'''
    if type_name[0] != '.':
        raise NotImplementedError("Lookup of non-absolute type names is not supported")
    return Names(type_name[1:].split('.'))

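# Illustrative example (hypothetical type name): a fully qualified proto type
# ".mypackage.Outer.Inner" parses to Names(('mypackage', 'Outer', 'Inner')),
# which formats to the C identifier "mypackage_Outer_Inner".
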
def varint_max_size(max_value):
    '''Returns the maximum number of bytes a varint can take when encoded.'''
    if max_value < 0:
        max_value = 2**64 - max_value
    for i in range(1, 11):
        if (max_value >> (i * 7)) == 0:
            return i
    raise ValueError("Value too large for varint: " + str(max_value))

assert varint_max_size(-1) == 10
assert varint_max_size(0) == 1
assert varint_max_size(127) == 1
assert varint_max_size(128) == 2

class EncodedSize:
    '''Class used to represent the encoded size of a field or a message.
    Consists of a combination of symbolic sizes and integer sizes.'''
    def __init__(self, value = 0, symbols = [], declarations = [], required_defines = []):
        if isinstance(value, EncodedSize):
            self.value = value.value
            self.symbols = value.symbols
            self.declarations = value.declarations
            self.required_defines = value.required_defines
        elif isinstance(value, strtypes + (Names,)):
            self.symbols = [str(value)]
            self.value = 0
            self.declarations = []
            self.required_defines = [str(value)]
        else:
            self.value = value
            self.symbols = symbols
            self.declarations = declarations
            self.required_defines = required_defines

    def __add__(self, other):
        if isinstance(other, int):
            return EncodedSize(self.value + other, self.symbols, self.declarations, self.required_defines)
        elif isinstance(other, strtypes + (Names,)):
            return EncodedSize(self.value, self.symbols + [str(other)], self.declarations, self.required_defines + [str(other)])
        elif isinstance(other, EncodedSize):
            return EncodedSize(self.value + other.value, self.symbols + other.symbols,
                               self.declarations + other.declarations, self.required_defines + other.required_defines)
        else:
            raise ValueError("Cannot add size: " + repr(other))

    def __mul__(self, other):
        if isinstance(other, int):
            return EncodedSize(self.value * other, [str(other) + '*' + s for s in self.symbols],
                               self.declarations, self.required_defines)
        else:
            raise ValueError("Cannot multiply size: " + repr(other))

    def __str__(self):
        if not self.symbols:
            return str(self.value)
        else:
            return '(' + str(self.value) + ' + ' + ' + '.join(self.symbols) + ')'

    def __repr__(self):
        return 'EncodedSize(%s, %s, %s, %s)' % (self.value, self.symbols, self.declarations, self.required_defines)

    def get_declarations(self):
        '''Get any declarations that must appear alongside this encoded size definition,
        such as helper union {} types.'''
        return '\n'.join(self.declarations)

    def get_cpp_guard(self, local_defines):
        '''Get an #if preprocessor statement listing all defines that are required for this definition.'''
        needed = [x for x in self.required_defines if x not in local_defines]
        if needed:
            return '#if ' + ' && '.join(['defined(%s)' % x for x in needed]) + "\n"
        else:
            return ''

    def upperlimit(self):
        if not self.symbols:
            return self.value
        else:
            return 2**32 - 1

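# Illustrative example (hypothetical symbol name): sizes combine numeric and
# symbolic parts, so str(EncodedSize(4) + 'OtherMsg_size') == '(4 + OtherMsg_size)',
# str(EncodedSize('OtherMsg_size') * 2) == '(0 + 2*OtherMsg_size)', and
# EncodedSize('OtherMsg_size').get_cpp_guard([]) == '#if defined(OtherMsg_size)\n'.
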
class ProtoElement(object):
    # Constants regarding path of proto elements in file descriptor.
    # They are used to connect proto elements with source code information (comments)
    # These values come from:
    # https://github.com/google/protobuf/blob/master/src/google/protobuf/descriptor.proto
    FIELD = 2
    MESSAGE = 4
    ENUM = 5
    NESTED_TYPE = 3
    NESTED_ENUM = 4

    def __init__(self, path, comments = None):
        '''
        path is a tuple containing integers (type, index, ...)
        comments is a dictionary mapping between element path & SourceCodeInfo.Location
            (contains information about source comments).
        '''
        assert(isinstance(path, tuple))
        self.element_path = path
        self.comments = comments or {}

    def get_member_comments(self, index):
        '''Get comments for a member of enum or message.'''
        return self.get_comments((ProtoElement.FIELD, index), leading_indent = True)

    def format_comment(self, comment):
        '''Put comment inside /* */ and sanitize comment contents'''
        comment = comment.strip()
        comment = comment.replace('/*', '/ *')
        comment = comment.replace('*/', '* /')
        return "/* %s */" % comment

    def get_comments(self, member_path = (), leading_indent = False):
        '''Get leading & trailing comments for a protobuf element.

        member_path is the proto path of an element or member (ex. [5 0] or [4 1 2 0])
        leading_indent is a flag that indicates if leading comments should be indented
        '''

        # Obtain SourceCodeInfo.Location object containing comment
        # information (based on the member path)
        path = self.element_path + member_path
        comment = self.comments.get(path)

        leading_comment = ""
        trailing_comment = ""

        if not comment:
            return leading_comment, trailing_comment

        if comment.leading_comments:
            leading_comment = "    " if leading_indent else ""
            leading_comment += self.format_comment(comment.leading_comments)

        if comment.trailing_comments:
            trailing_comment = self.format_comment(comment.trailing_comments)

        return leading_comment, trailing_comment

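# Illustrative example (path numbering taken from descriptor.proto): the path
# (MESSAGE, 1) == (4, 1) identifies the second message defined in a .proto file,
# and extending it with the member path (FIELD, 0) == (2, 0) identifies that
# message's first field when looking up SourceCodeInfo comments.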

class Enum(ProtoElement):
    def __init__(self, names, desc, enum_options, element_path, comments):
        '''
        desc is EnumDescriptorProto
        element_path is the proto path of this enum element inside the file
        comments is a dictionary mapping between element path & SourceCodeInfo.Location
            (contains information about source comments)
        '''
        super(Enum, self).__init__(element_path, comments)

        self.options = enum_options
        self.names = names

        # by definition, `names` includes this enum's name
        base_name = Names(names.parts[:-1])

        if enum_options.long_names:
            self.values = [(names + x.name, x.number) for x in desc.value]
        else:
            self.values = [(base_name + x.name, x.number) for x in desc.value]

        self.value_longnames = [self.names + x.name for x in desc.value]
        self.packed = enum_options.packed_enum

    def has_negative(self):
        for n, v in self.values:
            if v < 0:
                return True
        return False

    def encoded_size(self):
        return max([varint_max_size(v) for n,v in self.values])

    def __repr__(self):
        return 'Enum(%s)' % self.names

    def __str__(self):
        leading_comment, trailing_comment = self.get_comments()

        result = ''
        if leading_comment:
            result = '%s\n' % leading_comment

        result += 'typedef enum %s {' % Globals.naming_style.enum_name(self.names)
        if trailing_comment:
            result += " " + trailing_comment

        result += "\n"

        enum_length = len(self.values)
        enum_values = []
        for index, (name, value) in enumerate(self.values):
            leading_comment, trailing_comment = self.get_member_comments(index)

            if leading_comment:
                enum_values.append(leading_comment)

            comma = ","
            if index == enum_length - 1:
                # last enum member should not end with a comma
                comma = ""

            enum_value = "    %s = %d%s" % (Globals.naming_style.enum_entry(name), value, comma)
            if trailing_comment:
                enum_value += " " + trailing_comment

            enum_values.append(enum_value)

        result += '\n'.join(enum_values)
        result += '\n}'

        if self.packed:
            result += ' pb_packed'

        result += ' %s;' % Globals.naming_style.type_name(self.names)
        return result

    def auxiliary_defines(self):
        # sort the enum by value
        sorted_values = sorted(self.values, key = lambda x: (x[1], x[0]))
        result  = '#define %s %s\n' % (
            Globals.naming_style.define_name('_%s_MIN' % self.names),
            Globals.naming_style.enum_entry(sorted_values[0][0]))
        result += '#define %s %s\n' % (
            Globals.naming_style.define_name('_%s_MAX' % self.names),
            Globals.naming_style.enum_entry(sorted_values[-1][0]))
        result += '#define %s ((%s)(%s+1))\n' % (
            Globals.naming_style.define_name('_%s_ARRAYSIZE' % self.names),
            Globals.naming_style.type_name(self.names),
            Globals.naming_style.enum_entry(sorted_values[-1][0]))

        if not self.options.long_names:
            # Define the long names always so that enum value references
            # from other files work properly.
            for i, x in enumerate(self.values):
                result += '#define %s %s\n' % (self.value_longnames[i], x[0])

        if self.options.enum_to_string:
            result += 'const char *%s(%s v);\n' % (
                Globals.naming_style.func_name('%s_name' % self.names),
                Globals.naming_style.type_name(self.names))

        return result

    def enum_to_string_definition(self):
        if not self.options.enum_to_string:
            return ""

        result = 'const char *%s(%s v) {\n' % (
            Globals.naming_style.func_name('%s_name' % self.names),
            Globals.naming_style.type_name(self.names))

        result += '    switch (v) {\n'

        for ((enumname, _), strname) in zip(self.values, self.value_longnames):
            # Strip off the leading type name from the string value.
            strval = str(strname)[len(str(self.names)) + 1:]
            result += '        case %s: return "%s";\n' % (
                Globals.naming_style.enum_entry(enumname),
                Globals.naming_style.enum_entry(strval))

        result += '    }\n'
        result += '    return "unknown";\n'
        result += '}\n'

        return result
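
# Illustrative example (hypothetical enum, default NamingStyle, long_names on):
# an enum "Status" with values OK = 0 and ERROR = 1 renders via __str__(), with
# line breaks elided, as
#   typedef enum _Status { Status_OK = 0, Status_ERROR = 1 } Status;
# and auxiliary_defines() then emits the _Status_MIN, _Status_MAX and
# _Status_ARRAYSIZE defines based on the lowest and highest values.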

class FieldMaxSize:
    def __init__(self, worst = 0, checks = [], field_name = 'undefined'):
        if isinstance(worst, list):
            self.worst = max(i for i in worst if i is not None)
        else:
            self.worst = worst

        self.worst_field = field_name
        self.checks = list(checks)

    def extend(self, extend, field_name = None):
        self.worst = max(self.worst, extend.worst)

        if self.worst == extend.worst:
            self.worst_field = extend.worst_field

        self.checks.extend(extend.checks)

class Field(ProtoElement):
    macro_x_param = 'X'
    macro_a_param = 'a'

    def __init__(self, struct_name, desc, field_options, element_path = (), comments = None):
        '''desc is FieldDescriptorProto'''
        ProtoElement.__init__(self, element_path, comments)
        self.tag = desc.number
        self.struct_name = struct_name
        self.union_name = None
        self.name = desc.name
        self.default = None
        self.max_size = None
        self.max_count = None
        self.array_decl = ""
        self.enc_size = None
        self.data_item_size = None
        self.ctype = None
        self.fixed_count = False
        self.callback_datatype = field_options.callback_datatype
        self.math_include_required = False
        self.sort_by_tag = field_options.sort_by_tag

        if field_options.type == nanopb_pb2.FT_INLINE:
            # Before nanopb-0.3.8, fixed length bytes arrays were specified
            # by setting type to FT_INLINE. But to handle pointer typed fields,
            # it makes sense to have it as a separate option.
            field_options.type = nanopb_pb2.FT_STATIC
            field_options.fixed_length = True

        # Parse field options
        if field_options.HasField("max_size"):
            self.max_size = field_options.max_size

        self.default_has = field_options.default_has

        if desc.type == FieldD.TYPE_STRING and field_options.HasField("max_length"):
            # max_length overrides max_size for strings
            self.max_size = field_options.max_length + 1

        if field_options.HasField("max_count"):
            self.max_count = field_options.max_count

        if desc.HasField('default_value'):
            self.default = desc.default_value

        # Check field rules, i.e. required/optional/repeated.
        can_be_static = True
        if desc.label == FieldD.LABEL_REPEATED:
            self.rules = 'REPEATED'
            if self.max_count is None:
                can_be_static = False
            else:
                self.array_decl = '[%d]' % self.max_count
                if field_options.fixed_count:
                    self.rules = 'FIXARRAY'

        elif field_options.proto3:
            if desc.type == FieldD.TYPE_MESSAGE and not field_options.proto3_singular_msgs:
                # In most other protobuf libraries proto3 submessages have
                # "null" status. For nanopb, that is implemented as a has_ field.
                self.rules = 'OPTIONAL'
            elif hasattr(desc, "proto3_optional") and desc.proto3_optional:
                # Protobuf 3.12 introduced optional fields for proto3 syntax
                self.rules = 'OPTIONAL'
            else:
                # Proto3 singular fields (without has_field)
                self.rules = 'SINGULAR'
        elif desc.label == FieldD.LABEL_REQUIRED:
            self.rules = 'REQUIRED'
        elif desc.label == FieldD.LABEL_OPTIONAL:
            self.rules = 'OPTIONAL'
        else:
            raise NotImplementedError(desc.label)

        # Check if the field can be implemented with static allocation,
        # i.e. whether the data size is known.
        if desc.type == FieldD.TYPE_STRING and self.max_size is None:
            can_be_static = False

        if desc.type == FieldD.TYPE_BYTES and self.max_size is None:
            can_be_static = False

        # Decide how the field data will be allocated
        if field_options.type == nanopb_pb2.FT_DEFAULT:
            if can_be_static:
                field_options.type = nanopb_pb2.FT_STATIC
            else:
                field_options.type = field_options.fallback_type

        if field_options.type == nanopb_pb2.FT_STATIC and not can_be_static:
            raise Exception("Field '%s' is defined as static, but max_size or "
                            "max_count is not given." % self.name)

        if field_options.fixed_count and self.max_count is None:
            raise Exception("Field '%s' is defined as fixed count, "
                            "but max_count is not given." % self.name)

        if field_options.type == nanopb_pb2.FT_STATIC:
            self.allocation = 'STATIC'
        elif field_options.type == nanopb_pb2.FT_POINTER:
            self.allocation = 'POINTER'
        elif field_options.type == nanopb_pb2.FT_CALLBACK:
            self.allocation = 'CALLBACK'
        else:
            raise NotImplementedError(field_options.type)

        if field_options.HasField("type_override"):
            desc.type = field_options.type_override

        # Decide the C data type to use in the struct.
        if desc.type in datatypes:
            self.ctype, self.pbtype, self.enc_size, self.data_item_size = datatypes[desc.type]

            # Override the field size if the user wants to use smaller integers
            if (desc.type, field_options.int_size) in datatypes:
                self.ctype, self.pbtype, self.enc_size, self.data_item_size = datatypes[(desc.type, field_options.int_size)]
        elif desc.type == FieldD.TYPE_ENUM:
            self.pbtype = 'ENUM'
            self.data_item_size = 4
            self.ctype = names_from_type_name(desc.type_name)
            if self.default is not None:
                self.default = self.ctype + self.default
            self.enc_size = None # Needs to be filled in when enum values are known
        elif desc.type == FieldD.TYPE_STRING:
            self.pbtype = 'STRING'
            self.ctype = 'char'
            if self.allocation == 'STATIC':
                self.ctype = 'char'
                self.array_decl += '[%d]' % self.max_size
                # -1 because of the null terminator. Both pb_encode and pb_decode
                # check for its presence.
                self.enc_size = varint_max_size(self.max_size) + self.max_size - 1
        elif desc.type == FieldD.TYPE_BYTES:
            if field_options.fixed_length:
                self.pbtype = 'FIXED_LENGTH_BYTES'

                if self.max_size is None:
                    raise Exception("Field '%s' is defined as fixed length, "
                                    "but max_size is not given." % self.name)

                self.enc_size = varint_max_size(self.max_size) + self.max_size
                self.ctype = 'pb_byte_t'
                self.array_decl += '[%d]' % self.max_size
            else:
                self.pbtype = 'BYTES'
                self.ctype = 'pb_bytes_array_t'
                if self.allocation == 'STATIC':
                    self.ctype = Globals.naming_style.bytes_type(self.struct_name, self.name)
                    self.enc_size = varint_max_size(self.max_size) + self.max_size
        elif desc.type == FieldD.TYPE_MESSAGE:
            self.pbtype = 'MESSAGE'
            self.ctype = self.submsgname = names_from_type_name(desc.type_name)
            self.enc_size = None # Needs to be filled in after the message type is available
            if field_options.submsg_callback and self.allocation == 'STATIC':
                self.pbtype = 'MSG_W_CB'
        else:
            raise NotImplementedError(desc.type)

        if self.default and self.pbtype in ['FLOAT', 'DOUBLE']:
            if 'inf' in self.default or 'nan' in self.default:
                self.math_include_required = True

    def __lt__(self, other):
        return self.tag < other.tag

    def __repr__(self):
        return 'Field(%s)' % self.name

    def __str__(self):
        result = ''

        var_name = Globals.naming_style.var_name(self.name)
        type_name = Globals.naming_style.type_name(self.ctype) if isinstance(self.ctype, Names) else self.ctype

        if self.allocation == 'POINTER':
            if self.rules == 'REPEATED':
                if self.pbtype == 'MSG_W_CB':
                    result += '    pb_callback_t cb_' + var_name + ';\n'
                result += '    pb_size_t ' + var_name + '_count;\n'

            if self.rules == 'FIXARRAY' and self.pbtype in ['STRING', 'BYTES']:
                # Pointer to fixed size array of pointers
                result += '    %s* (*%s)%s;' % (type_name, var_name, self.array_decl)
            elif self.pbtype == 'FIXED_LENGTH_BYTES' or self.rules == 'FIXARRAY':
                # Pointer to fixed size array of items
                result += '    %s (*%s)%s;' % (type_name, var_name, self.array_decl)
            elif self.rules == 'REPEATED' and self.pbtype in ['STRING', 'BYTES']:
                # String/bytes arrays need to be defined as pointers to pointers
                result += '    %s **%s;' % (type_name, var_name)
            elif self.pbtype in ['MESSAGE', 'MSG_W_CB']:
                # Use struct definition, so recursive submessages are possible
                result += '    struct %s *%s;' % (Globals.naming_style.struct_name(self.ctype), var_name)
            else:
                # Normal case, just a pointer to single item
                result += '    %s *%s;' % (type_name, var_name)
        elif self.allocation == 'CALLBACK':
            result += '    %s %s;' % (self.callback_datatype, var_name)
        else:
            if self.pbtype == 'MSG_W_CB' and self.rules in ['OPTIONAL', 'REPEATED']:
                result += '    pb_callback_t cb_' + var_name + ';\n'

            if self.rules == 'OPTIONAL':
                result += '    bool has_' + var_name + ';\n'
            elif self.rules == 'REPEATED':
                result += '    pb_size_t ' + var_name + '_count;\n'

            result += '    %s %s%s;' % (type_name, var_name, self.array_decl)

        leading_comment, trailing_comment = self.get_comments(leading_indent = True)
        if leading_comment: result = leading_comment + "\n" + result
        if trailing_comment: result = result + " " + trailing_comment

        return result

    def types(self):
        '''Return definitions for any special types this field might need.'''
        if self.pbtype == 'BYTES' and self.allocation == 'STATIC':
            result = 'typedef PB_BYTES_ARRAY_T(%d) %s;\n' % (self.max_size, Globals.naming_style.var_name(self.ctype))
        else:
            result = ''
        return result

    def get_dependencies(self):
        '''Get list of type names used by this field.'''
        if self.allocation == 'STATIC':
            return [str(self.ctype)]
        elif self.allocation == 'POINTER' and self.rules == 'FIXARRAY':
            return [str(self.ctype)]
        else:
            return []

    def get_initializer(self, null_init, inner_init_only = False):
        '''Return literal expression for this field's default value.
        null_init: If True, initialize to a 0 value instead of default from .proto
        inner_init_only: If True, exclude initialization for any count/has fields
        '''

        inner_init = None
        if self.pbtype in ['MESSAGE', 'MSG_W_CB']:
            if null_init:
                inner_init = Globals.naming_style.define_name('%s_init_zero' % self.ctype)
            else:
                inner_init = Globals.naming_style.define_name('%s_init_default' % self.ctype)
        elif self.default is None or null_init:
            if self.pbtype == 'STRING':
                inner_init = '""'
            elif self.pbtype == 'BYTES':
                inner_init = '{0, {0}}'
            elif self.pbtype == 'FIXED_LENGTH_BYTES':
                inner_init = '{0}'
            elif self.pbtype in ('ENUM', 'UENUM'):
                inner_init = '_%s_MIN' % Globals.naming_style.define_name(self.ctype)
            else:
                inner_init = '0'
        else:
            if self.pbtype == 'STRING':
                data = codecs.escape_encode(self.default.encode('utf-8'))[0]
                inner_init = '"' + data.decode('ascii') + '"'
            elif self.pbtype == 'BYTES':
                data = codecs.escape_decode(self.default)[0]
                data = ["0x%02x" % c for c in bytearray(data)]
                if len(data) == 0:
                    inner_init = '{0, {0}}'
                else:
                    inner_init = '{%d, {%s}}' % (len(data), ','.join(data))
            elif self.pbtype == 'FIXED_LENGTH_BYTES':
                data = codecs.escape_decode(self.default)[0]
                data = ["0x%02x" % c for c in bytearray(data)]
                if len(data) == 0:
                    inner_init = '{0}'
                else:
                    inner_init = '{%s}' % ','.join(data)
            elif self.pbtype in ['FIXED32', 'UINT32']:
                inner_init = str(self.default) + 'u'
            elif self.pbtype in ['FIXED64', 'UINT64']:
                inner_init = str(self.default) + 'ull'
            elif self.pbtype in ['SFIXED64', 'INT64']:
                inner_init = str(self.default) + 'll'
            elif self.pbtype in ['FLOAT', 'DOUBLE']:
                inner_init = str(self.default)
                if 'inf' in inner_init:
                    inner_init = inner_init.replace('inf', 'INFINITY')
                elif 'nan' in inner_init:
                    inner_init = inner_init.replace('nan', 'NAN')
                elif (not '.' in inner_init) and self.pbtype == 'FLOAT':
                    inner_init += '.0f'
                elif self.pbtype == 'FLOAT':
                    inner_init += 'f'
            else:
                inner_init = str(self.default)

        if inner_init_only:
            return inner_init

        outer_init = None
        if self.allocation == 'STATIC':
            if self.rules == 'REPEATED':
                outer_init = '0, {' + ', '.join([inner_init] * self.max_count) + '}'
            elif self.rules == 'FIXARRAY':
                outer_init = '{' + ', '.join([inner_init] * self.max_count) + '}'
            elif self.rules == 'OPTIONAL':
                if null_init or not self.default_has:
                    outer_init = 'false, ' + inner_init
                else:
                    outer_init = 'true, ' + inner_init
            else:
                outer_init = inner_init
        elif self.allocation == 'POINTER':
            if self.rules == 'REPEATED':
                outer_init = '0, NULL'
            else:
                outer_init = 'NULL'
        elif self.allocation == 'CALLBACK':
            if self.pbtype == 'EXTENSION':
                outer_init = 'NULL'
            else:
                outer_init = '{{NULL}, NULL}'

        if self.pbtype == 'MSG_W_CB' and self.rules in ['REPEATED', 'OPTIONAL']:
            outer_init = '{{NULL}, NULL}, ' + outer_init

        return outer_init
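
    # Illustrative example (hypothetical field): a static repeated int32 field
    # with max_count 3 and no .proto default produces the initializer
    # '0, {0, 0, 0}', i.e. the _count member followed by the array contents.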

    def tags(self):
        '''Return the #define for the tag number of this field.'''
        identifier = Globals.naming_style.define_name('%s_%s_tag' % (self.struct_name, self.name))
        return '#define %-40s %d\n' % (identifier, self.tag)

    def fieldlist(self):
        '''Return the FIELDLIST macro entry for this field.
        Format is: X(a, ATYPE, HTYPE, LTYPE, field_name, tag)
        '''
        name = Globals.naming_style.var_name(self.name)

        if self.rules == "ONEOF":
            # For oneofs, make a tuple of the union name, union member name,
            # and the name inside the parent struct.
            if not self.anonymous:
                name = '(%s,%s,%s)' % (
                    Globals.naming_style.var_name(self.union_name),
                    Globals.naming_style.var_name(self.name),
                    Globals.naming_style.var_name(self.union_name) + '.' +
                    Globals.naming_style.var_name(self.name))
            else:
                name = '(%s,%s,%s)' % (
                    Globals.naming_style.var_name(self.union_name),
                    Globals.naming_style.var_name(self.name),
                    Globals.naming_style.var_name(self.name))

        return '%s(%s, %-9s %-9s %-9s %-16s %3d)' % (self.macro_x_param,
                                                     self.macro_a_param,
                                                     self.allocation + ',',
                                                     self.rules + ',',
                                                     self.pbtype + ',',
                                                     name + ',',
                                                     self.tag)
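
    # Illustrative example (hypothetical field): a static singular int32 field
    # named "value" with tag 1 in message "MyMessage" yields the tag define
    # '#define MyMessage_value_tag 1' and a field list entry of roughly
    # 'X(a, STATIC, SINGULAR, INT32, value, 1)' (column padding elided).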

    def data_size(self, dependencies):
        '''Return the estimated size of this field in the C struct.
        This is used to try to automatically pick the right descriptor size.
        If the estimate is wrong, it will result in a compile-time error and
        the user having to specify the descriptor_width option.
        '''
        if self.allocation == 'POINTER' or self.pbtype == 'EXTENSION':
            size = 8
            alignment = 8
        elif self.allocation == 'CALLBACK':
            size = 16
            alignment = 8
        elif self.pbtype in ['MESSAGE', 'MSG_W_CB']:
            alignment = 8
            if str(self.submsgname) in dependencies:
                other_dependencies = dict(x for x in dependencies.items() if x[0] != str(self.struct_name))
                size = dependencies[str(self.submsgname)].data_size(other_dependencies)
            else:
                size = 256 # Message is in other file, this is a reasonable guess for most cases
                sys.stderr.write('Could not determine size for submessage %s, using default %d\n' % (self.submsgname, size))

            if self.pbtype == 'MSG_W_CB':
                size += 16
        elif self.pbtype in ['STRING', 'FIXED_LENGTH_BYTES']:
            size = self.max_size
            alignment = 4
        elif self.pbtype == 'BYTES':
            size = self.max_size + 4
            alignment = 4
        elif self.data_item_size is not None:
            size = self.data_item_size
            alignment = 4
            if self.data_item_size >= 8:
                alignment = 8
        else:
            raise Exception("Unhandled field type: %s" % self.pbtype)

        if self.rules in ['REPEATED', 'FIXARRAY'] and self.allocation == 'STATIC':
            size *= self.max_count

        if self.rules not in ('REQUIRED', 'SINGULAR'):
            size += 4

        if size % alignment != 0:
            # Estimate how much alignment requirements will increase the size.
            size += alignment - (size % alignment)

        return size

    def encoded_size(self, dependencies):
        '''Return the maximum size that this field can take when encoded,
        including the field tag. If the size cannot be determined, returns
        None.'''

        if self.allocation != 'STATIC':
            return None

        if self.pbtype in ['MESSAGE', 'MSG_W_CB']:
            encsize = None
            if str(self.submsgname) in dependencies:
                submsg = dependencies[str(self.submsgname)]
                other_dependencies = dict(x for x in dependencies.items() if x[0] != str(self.struct_name))
                encsize = submsg.encoded_size(other_dependencies)

                my_msg = dependencies.get(str(self.struct_name))
                external = (not my_msg or submsg.protofile != my_msg.protofile)

                if encsize and encsize.symbols and external:
                    # Couldn't fully resolve the size of a dependency from
                    # another file. Instead of including the symbols directly,
                    # just use the #define SubMessage_size from the header.
                    encsize = None

                if encsize is not None:
                    # Include submessage length prefix
                    encsize += varint_max_size(encsize.upperlimit())
                elif not external:
                    # The dependency is from the same file and its size cannot
                    # be determined, thus we know it will not be possible at
                    # runtime either.
                    return None

            if encsize is None:
                # The submessage or its size cannot be found.
                # This can occur if the submessage is defined in a different
                # file, and it or its .options could not be found.
                # Instead of a direct numeric value, reference the size that
                # has been #defined in the other file.
                encsize = EncodedSize(self.submsgname + 'size')

                # We will have to make a conservative assumption on the length
                # prefix size, though.
                encsize += 5

        elif self.pbtype in ['ENUM', 'UENUM']:
            if str(self.ctype) in dependencies:
                enumtype = dependencies[str(self.ctype)]
                encsize = enumtype.encoded_size()
            else:
                # Conservative assumption
                encsize = 10

        elif self.enc_size is None:
            raise RuntimeError("Could not determine encoded size for %s.%s"
                               % (self.struct_name, self.name))
        else:
            encsize = EncodedSize(self.enc_size)

        encsize += varint_max_size(self.tag << 3) # Tag + wire type

        if self.rules in ['REPEATED', 'FIXARRAY']:
            # Decoders must always be able to handle unpacked arrays.
            # Therefore we have to reserve space for them, even though
            # we emit packed arrays ourselves. For a length of 1, packed
            # arrays are larger, however, so we need to add an allowance
            # for the length byte.
            encsize *= self.max_count

            if self.max_count == 1:
                encsize += 1

        return encsize

    def has_callbacks(self):
        return self.allocation == 'CALLBACK'

    def requires_custom_field_callback(self):
        return self.allocation == 'CALLBACK' and self.callback_datatype != 'pb_callback_t'

class ExtensionRange(Field):
    def __init__(self, struct_name, range_start, field_options):
        '''Implements a special pb_extension_t* field in an extensible message
        structure. The range_start signifies the index at which the extensions
        start. Not all tags above this are necessarily extensions; it is merely
        a speed optimization.
        '''
        self.tag = range_start
        self.struct_name = struct_name
        self.name = 'extensions'
        self.pbtype = 'EXTENSION'
        self.rules = 'OPTIONAL'
        self.allocation = 'CALLBACK'
        self.ctype = 'pb_extension_t'
        self.array_decl = ''
        self.default = None
        self.max_size = 0
        self.max_count = 0
        self.data_item_size = 0
        self.fixed_count = False
        self.callback_datatype = 'pb_extension_t*'

    def requires_custom_field_callback(self):
        return False

    def __str__(self):
        return '    pb_extension_t *extensions;'

    def types(self):
        return ''

    def tags(self):
        return ''

    def encoded_size(self, dependencies):
        # We exclude extensions from the count, because they cannot be known
        # until runtime. Another option would be to return None here, but this
        # way the value remains useful if extensions are not used.
        return EncodedSize(0)

class ExtensionField(Field):
    def __init__(self, fullname, desc, field_options):
        self.fullname = fullname
        self.extendee_name = names_from_type_name(desc.extendee)
        Field.__init__(self, self.fullname + "extmsg", desc, field_options)

        if self.rules != 'OPTIONAL':
            self.skip = True
        else:
            self.skip = False
            self.rules = 'REQUIRED' # We don't really want the has_field for extensions
            # currently no support for comments for extension fields => provide (), {}
            self.msg = Message(self.fullname + "extmsg", None, field_options, (), {})
            self.msg.fields.append(self)

    def tags(self):
        '''Return the #define for the tag number of this field.'''
        identifier = Globals.naming_style.define_name('%s_tag' % (self.fullname))
        return '#define %-40s %d\n' % (identifier, self.tag)

    def extension_decl(self):
        '''Declaration of the extension type in the .pb.h file'''
        if self.skip:
            msg = '/* Extension field %s was skipped because only "optional"\n' % self.fullname
            msg += '   type of extension fields is currently supported. */\n'
            return msg

        return ('extern const pb_extension_type_t %s; /* field type: %s */\n' %
            (Globals.naming_style.var_name(self.fullname), str(self).strip()))

    def extension_def(self, dependencies):
        '''Definition of the extension type in the .pb.c file'''

        if self.skip:
            return ''

        result = "/* Definition for extension field %s */\n" % self.fullname
        result += str(self.msg)
        result += self.msg.fields_declaration(dependencies)
        result += 'pb_byte_t %s_default[] = {0x00};\n' % self.msg.name
        result += self.msg.fields_definition(dependencies)
        result += 'const pb_extension_type_t %s = {\n' % Globals.naming_style.var_name(self.fullname)
        result += '    NULL,\n'
        result += '    NULL,\n'
        result += '    &%s_msg\n' % Globals.naming_style.type_name(self.msg.name)
        result += '};\n'
        return result


# ---------------------------------------------------------------------------
#                   Generation of oneofs (unions)
# ---------------------------------------------------------------------------

class OneOf(Field):
    def __init__(self, struct_name, oneof_desc, oneof_options):
        self.struct_name = struct_name
        self.name = oneof_desc.name
        self.ctype = 'union'
        self.pbtype = 'oneof'
        self.fields = []
        self.allocation = 'ONEOF'
        self.default = None
        self.rules = 'ONEOF'
        self.anonymous = oneof_options.anonymous_oneof
        self.sort_by_tag = oneof_options.sort_by_tag
        self.has_msg_cb = False

    def add_field(self, field):
        field.union_name = self.name
        field.rules = 'ONEOF'
        field.anonymous = self.anonymous
        self.fields.append(field)

        if self.sort_by_tag:
            self.fields.sort()

        if field.pbtype == 'MSG_W_CB':
            self.has_msg_cb = True

        # Sort by the lowest tag number inside union
        self.tag = min([f.tag for f in self.fields])

    def __str__(self):
        result = ''
        if self.fields:
            if self.has_msg_cb:
                result += '    pb_callback_t cb_' + Globals.naming_style.var_name(self.name) + ';\n'

            result += '    pb_size_t which_' + Globals.naming_style.var_name(self.name) + ";\n"
            result += '    union {\n'
            for f in self.fields:
                result += '    ' + str(f).replace('\n', '\n    ') + '\n'
            if self.anonymous:
                result += '    };'
            else:
                result += '    } ' + Globals.naming_style.var_name(self.name) + ';'
        return result

    def types(self):
        return ''.join([f.types() for f in self.fields])

    def get_dependencies(self):
        deps = []
        for f in self.fields:
            deps += f.get_dependencies()
        return deps

    def get_initializer(self, null_init):
        if self.has_msg_cb:
            return '{{NULL}, NULL}, 0, {' + self.fields[0].get_initializer(null_init) + '}'
        else:
            return '0, {' + self.fields[0].get_initializer(null_init) + '}'

    def tags(self):
        return ''.join([f.tags() for f in self.fields])

    def data_size(self, dependencies):
        return max(f.data_size(dependencies) for f in self.fields)

    def encoded_size(self, dependencies):
        '''Returns the size of the largest oneof field.'''
        largest = 0
        dynamic_sizes = {}
        for f in self.fields:
            size = EncodedSize(f.encoded_size(dependencies))
            if size is None or size.value is None:
                return None
            elif size.symbols:
                dynamic_sizes[f.tag] = size
            elif size.value > largest:
                largest = size.value

        if not dynamic_sizes:
            # Simple case, all sizes were known at generator time
            return EncodedSize(largest)

        if largest > 0:
            # Some sizes were known, some were not
            dynamic_sizes[0] = EncodedSize(largest)

        # Couldn't find size for submessage at generation time,
        # have to rely on macro resolution at compile time.
        if len(dynamic_sizes) == 1:
            # Only one symbol was needed
            return list(dynamic_sizes.values())[0]
        else:
            # Use sizeof(union{}) construct to find the maximum size of
            # submessages.
            union_name = "%s_%s_size_union" % (self.struct_name, self.name)
            union_def = 'union %s {%s};\n' % (union_name, ' '.join('char f%d[%s];' % (k, s) for k,s in dynamic_sizes.items()))
            required_defs = list(itertools.chain.from_iterable(s.required_defines for k,s in dynamic_sizes.items()))
            return EncodedSize(0, ['sizeof(union %s)' % union_name], [union_def], required_defs)
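
    # Illustrative example (hypothetical names): if the members' sizes can only
    # be expressed symbolically (e.g. Foo_size, Bar_size), a helper union such
    # as 'union MyMsg_payload_size_union { char f1[...]; char f2[...]; };' is
    # declared, and 'sizeof(union MyMsg_payload_size_union)' stands in for the
    # largest member size.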
1215
1216    def has_callbacks(self):
1217        return bool([f for f in self.fields if f.has_callbacks()])
1218
1219    def requires_custom_field_callback(self):
1220        return bool([f for f in self.fields if f.requires_custom_field_callback()])
1221
1222# ---------------------------------------------------------------------------
1223#                   Generation of messages (structures)
1224# ---------------------------------------------------------------------------
1225
1226
1227class Message(ProtoElement):
1228    def __init__(self, names, desc, message_options, element_path, comments):
1229        super(Message, self).__init__(element_path, comments)
1230        self.name = names
1231        self.fields = []
1232        self.oneofs = {}
1233        self.desc = desc
1234        self.math_include_required = False
1235        self.packed = message_options.packed_struct
1236        self.descriptorsize = message_options.descriptorsize
1237
1238        if message_options.msgid:
1239            self.msgid = message_options.msgid
1240
1241        if desc is not None:
1242            self.load_fields(desc, message_options)
1243
1244        self.callback_function = message_options.callback_function
1245        if not message_options.HasField('callback_function'):
1246            # Automatically assign a per-message callback if any field has
1247            # a special callback_datatype.
1248            for field in self.fields:
1249                if field.requires_custom_field_callback():
1250                    self.callback_function = "%s_callback" % self.name
1251                    break
1252
1253    def load_fields(self, desc, message_options):
1254        '''Load field list from DescriptorProto'''
1255
1256        no_unions = []
1257
1258        if hasattr(desc, 'oneof_decl'):
1259            for i, f in enumerate(desc.oneof_decl):
1260                oneof_options = get_nanopb_suboptions(desc, message_options, self.name + f.name)
1261                if oneof_options.no_unions:
1262                    no_unions.append(i) # No union, but add fields normally
1263                elif oneof_options.type == nanopb_pb2.FT_IGNORE:
1264                    pass # No union and skip fields also
1265                else:
1266                    oneof = OneOf(self.name, f, oneof_options)
1267                    self.oneofs[i] = oneof
1268        else:
1269            sys.stderr.write('Note: This Python protobuf library has no OneOf support\n')
1270
1271        for index, f in enumerate(desc.field):
1272            field_options = get_nanopb_suboptions(f, message_options, self.name + f.name)
1273            if field_options.type == nanopb_pb2.FT_IGNORE:
1274                continue
1275
1276            if field_options.descriptorsize > self.descriptorsize:
1277                self.descriptorsize = field_options.descriptorsize
1278
1279            field = Field(self.name, f, field_options, self.element_path + (ProtoElement.FIELD, index), self.comments)
1280            if hasattr(f, 'oneof_index') and f.HasField('oneof_index'):
1281                if hasattr(f, 'proto3_optional') and f.proto3_optional:
1282                    no_unions.append(f.oneof_index)
1283
1284                if f.oneof_index in no_unions:
1285                    self.fields.append(field)
1286                elif f.oneof_index in self.oneofs:
1287                    self.oneofs[f.oneof_index].add_field(field)
1288
1289                    if self.oneofs[f.oneof_index] not in self.fields:
1290                        self.fields.append(self.oneofs[f.oneof_index])
1291            else:
1292                self.fields.append(field)
1293
1294            if field.math_include_required:
1295                self.math_include_required = True
1296
1297        if len(desc.extension_range) > 0:
1298            field_options = get_nanopb_suboptions(desc, message_options, self.name + 'extensions')
1299            range_start = min([r.start for r in desc.extension_range])
1300            if field_options.type != nanopb_pb2.FT_IGNORE:
1301                self.fields.append(ExtensionRange(self.name, range_start, field_options))
1302
1303        if message_options.sort_by_tag:
1304            self.fields.sort()
1305
1306    def get_dependencies(self):
1307        '''Get list of type names that this structure refers to.'''
1308        deps = []
1309        for f in self.fields:
1310            deps += f.get_dependencies()
1311        return deps
1312
1313    def __repr__(self):
1314        return 'Message(%s)' % self.name
1315
1316    def __str__(self):
1317        leading_comment, trailing_comment = self.get_comments()
1318
1319        result = ''
1320        if leading_comment:
1321            result = '%s\n' % leading_comment
1322
1323        result += 'typedef struct %s {' % Globals.naming_style.struct_name(self.name)
1324        if trailing_comment:
1325            result += " " + trailing_comment
1326
1327        result += '\n'
1328
1329        if not self.fields:
            # Empty structs are not allowed by the C standard.
            # Therefore add a dummy field when a message has no fields.
1332            result += '    char dummy_field;'
1333
1334        result += '\n'.join([str(f) for f in self.fields])
1335
1336        if Globals.protoc_insertion_points:
1337            result += '\n/* @@protoc_insertion_point(struct:%s) */' % self.name
1338
1339        result += '\n}'
1340
1341        if self.packed:
1342            result += ' pb_packed'
1343
1344        result += ' %s;' % Globals.naming_style.type_name(self.name)
1345
1346        if self.packed:
1347            result = 'PB_PACKED_STRUCT_START\n' + result
1348            result += '\nPB_PACKED_STRUCT_END'
1349
1350        return result + '\n'
1351
1352    def types(self):
1353        return ''.join([f.types() for f in self.fields])
1354
1355    def get_initializer(self, null_init):
1356        if not self.fields:
1357            return '{0}'
1358
1359        parts = []
1360        for field in self.fields:
1361            parts.append(field.get_initializer(null_init))
1362        return '{' + ', '.join(parts) + '}'
1363
1364    def count_required_fields(self):
1365        '''Returns number of required fields inside this message'''
1366        count = 0
1367        for f in self.fields:
1368            if not isinstance(f, OneOf):
1369                if f.rules == 'REQUIRED':
1370                    count += 1
1371        return count
1372
1373    def all_fields(self):
1374        '''Iterate over all fields in this message, including nested OneOfs.'''
1375        for f in self.fields:
1376            if isinstance(f, OneOf):
1377                for f2 in f.fields:
1378                    yield f2
1379            else:
1380                yield f
1381
1382
1383    def field_for_tag(self, tag):
1384        '''Given a tag number, return the Field instance.'''
1385        for field in self.all_fields():
1386            if field.tag == tag:
1387                return field
1388        return None
1389
1390    def count_all_fields(self):
1391        '''Count the total number of fields in this message.'''
1392        count = 0
1393        for f in self.fields:
1394            if isinstance(f, OneOf):
1395                count += len(f.fields)
1396            else:
1397                count += 1
1398        return count
1399
1400    def fields_declaration(self, dependencies):
1401        '''Return X-macro declaration of all fields in this message.'''
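        # With the default naming style the result looks roughly like this
        # (illustrative, for a hypothetical message "Example" with a single
        # int32 field "foo" at tag 1):
        #   #define Example_FIELDLIST(X, a) \
        #   X(a, STATIC, SINGULAR, INT32, foo, 1)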
1402        Field.macro_x_param = 'X'
1403        Field.macro_a_param = 'a'
1404        while any(field.name == Field.macro_x_param for field in self.all_fields()):
1405            Field.macro_x_param += '_'
1406        while any(field.name == Field.macro_a_param for field in self.all_fields()):
1407            Field.macro_a_param += '_'
1408
        # The field descriptor array must be sorted by tag number; pb_common.c relies on it.
1410        sorted_fields = list(self.all_fields())
1411        sorted_fields.sort(key = lambda x: x.tag)
1412
1413        result = '#define %s_FIELDLIST(%s, %s) \\\n' % (
1414            Globals.naming_style.define_name(self.name),
1415            Field.macro_x_param,
1416            Field.macro_a_param)
1417        result += ' \\\n'.join(x.fieldlist() for x in sorted_fields)
1418        result += '\n'
1419
        has_callbacks = any(f.has_callbacks() for f in self.fields)
1421        if has_callbacks:
1422            if self.callback_function != 'pb_default_field_callback':
1423                result += "extern bool %s(pb_istream_t *istream, pb_ostream_t *ostream, const pb_field_t *field);\n" % self.callback_function
1424            result += "#define %s_CALLBACK %s\n" % (
1425                Globals.naming_style.define_name(self.name),
1426                self.callback_function)
1427        else:
1428            result += "#define %s_CALLBACK NULL\n" % Globals.naming_style.define_name(self.name)
1429
1430        defval = self.default_value(dependencies)
1431        if defval:
1432            hexcoded = ''.join("\\x%02x" % ord(defval[i:i+1]) for i in range(len(defval)))
1433            result += '#define %s_DEFAULT (const pb_byte_t*)"%s\\x00"\n' % (
1434                Globals.naming_style.define_name(self.name),
1435                hexcoded)
1436        else:
1437            result += '#define %s_DEFAULT NULL\n' % Globals.naming_style.define_name(self.name)
1438
1439        for field in sorted_fields:
1440            if field.pbtype in ['MESSAGE', 'MSG_W_CB']:
1441                if field.rules == 'ONEOF':
1442                    result += "#define %s_%s_%s_MSGTYPE %s\n" % (
1443                        Globals.naming_style.type_name(self.name),
1444                        Globals.naming_style.var_name(field.union_name),
1445                        Globals.naming_style.var_name(field.name),
1446                        Globals.naming_style.type_name(field.ctype)
1447                    )
1448                else:
1449                    result += "#define %s_%s_MSGTYPE %s\n" % (
1450                        Globals.naming_style.type_name(self.name),
1451                        Globals.naming_style.var_name(field.name),
1452                        Globals.naming_style.type_name(field.ctype)
1453                    )
1454
1455        return result
1456
1457    def enumtype_defines(self):
1458        '''Defines to allow user code to refer to enum type of a specific field'''
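        # e.g. for a message "Example" with an enum field "status" of enum type
        # "Status", this yields (names here are illustrative):
        #   #define Example_status_ENUMTYPE Status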
1459        result = ''
1460        for field in self.all_fields():
1461            if field.pbtype in ['ENUM', "UENUM"]:
1462                if field.rules == 'ONEOF':
1463                    result += "#define %s_%s_%s_ENUMTYPE %s\n" % (
1464                        Globals.naming_style.type_name(self.name),
1465                        Globals.naming_style.var_name(field.union_name),
1466                        Globals.naming_style.var_name(field.name),
1467                        Globals.naming_style.type_name(field.ctype)
1468                    )
1469                else:
1470                    result += "#define %s_%s_ENUMTYPE %s\n" % (
1471                        Globals.naming_style.type_name(self.name),
1472                        Globals.naming_style.var_name(field.name),
1473                        Globals.naming_style.type_name(field.ctype)
1474                    )
1475
1476        return result
1477
1478    def fields_declaration_cpp_lookup(self):
1479        result = 'template <>\n'
1480        result += 'struct MessageDescriptor<%s> {\n' % (self.name)
1481        result += '    static PB_INLINE_CONSTEXPR const pb_size_t fields_array_length = %d;\n' % (self.count_all_fields())
1482        result += '    static inline const pb_msgdesc_t* fields() {\n'
1483        result += '        return &%s_msg;\n' % (self.name)
1484        result += '    }\n'
1485        result += '};'
1486        return result
1487
1488    def fields_definition(self, dependencies):
1489        '''Return the field descriptor definition that goes in .pb.c file.'''
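        # e.g. for a message "Example" this yields something like
        # "PB_BIND(Example, Example, AUTO)"; the exact first argument depends
        # on the active naming style, and AUTO is used when width 1 suffices.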
1490        width = self.required_descriptor_width(dependencies)
        if width == 1:
            width = 'AUTO'
1493
1494        result = 'PB_BIND(%s, %s, %s)\n' % (
1495            Globals.naming_style.define_name(self.name),
1496            Globals.naming_style.type_name(self.name),
1497            width)
1498        return result
1499
1500    def required_descriptor_width(self, dependencies):
1501        '''Estimate how many words are necessary for each field descriptor.'''
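        # For example, a message whose largest tag number is 70 and whose
        # struct offsets stay below 64 kB falls through to the 2-word case
        # below (tag > 0x3F), while tags above 0x3FF or offsets above 0xFFFF
        # require the 4-word descriptor format.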
1502        if self.descriptorsize != nanopb_pb2.DS_AUTO:
1503            return int(self.descriptorsize)
1504
        if not self.fields:
            return 1
1507
1508        max_tag = max(field.tag for field in self.all_fields())
1509        max_offset = self.data_size(dependencies)
1510        max_arraysize = max((field.max_count or 0) for field in self.all_fields())
1511        max_datasize = max(field.data_size(dependencies) for field in self.all_fields())
1512
1513        if max_arraysize > 0xFFFF:
1514            return 8
1515        elif (max_tag > 0x3FF or max_offset > 0xFFFF or
1516              max_arraysize > 0x0FFF or max_datasize > 0x0FFF):
1517            return 4
1518        elif max_tag > 0x3F or max_offset > 0xFF:
1519            return 2
1520        else:
1521            # NOTE: Macro logic in pb.h ensures that width 1 will
1522            # be raised to 2 automatically for string/submsg fields
1523            # and repeated fields. Thus only tag and offset need to
1524            # be checked.
1525            return 1
1526
1527    def data_size(self, dependencies):
1528        '''Return approximate sizeof(struct) in the compiled code.'''
1529        return sum(f.data_size(dependencies) for f in self.fields)
1530
1531    def encoded_size(self, dependencies):
1532        '''Return the maximum size that this message can take when encoded.
1533        If the size cannot be determined, returns None.
1534        '''
1535        size = EncodedSize(0)
1536        for field in self.fields:
1537            fsize = field.encoded_size(dependencies)
1538            if fsize is None:
1539                return None
1540            size += fsize
1541
1542        return size
1543
1544    def default_value(self, dependencies):
1545        '''Generate serialized protobuf message that contains the
1546        default values for optional fields.'''
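        # For example, 'optional int32 answer = 1 [default = 42];' serializes
        # to b'\x08\x2a' here; fields_declaration() then emits that blob as
        # ..._DEFAULT (const pb_byte_t*)"\x08\x2a\x00".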
1547
1548        if not self.desc:
1549            return b''
1550
1551        if self.desc.options.map_entry:
1552            return b''
1553
1554        optional_only = copy.deepcopy(self.desc)
1555
1556        # Remove fields without default values
1557        # The iteration is done in reverse order to avoid remove() messing up iteration.
1558        for field in reversed(list(optional_only.field)):
1559            field.ClearField(str('extendee'))
1560            parsed_field = self.field_for_tag(field.number)
1561            if parsed_field is None or parsed_field.allocation != 'STATIC':
1562                optional_only.field.remove(field)
1563            elif (field.label == FieldD.LABEL_REPEATED or
1564                  field.type == FieldD.TYPE_MESSAGE):
1565                optional_only.field.remove(field)
1566            elif hasattr(field, 'oneof_index') and field.HasField('oneof_index'):
1567                optional_only.field.remove(field)
1568            elif field.type == FieldD.TYPE_ENUM:
1569                # The partial descriptor doesn't include the enum type
1570                # so we fake it with int64.
1571                enumname = names_from_type_name(field.type_name)
1572                try:
1573                    enumtype = dependencies[str(enumname)]
1574                except KeyError:
1575                    raise Exception("Could not find enum type %s while generating default values for %s.\n" % (enumname, self.name)
1576                                    + "Try passing all source files to generator at once, or use -I option.")
1577
1578                if not isinstance(enumtype, Enum):
1579                    raise Exception("Expected enum type as %s, got %s" % (enumname, repr(enumtype)))
1580
1581                if field.HasField('default_value'):
1582                    defvals = [v for n,v in enumtype.values if n.parts[-1] == field.default_value]
1583                else:
1584                    # If no default is specified, the default is the first value.
1585                    defvals = [v for n,v in enumtype.values]
1586                if defvals and defvals[0] != 0:
1587                    field.type = FieldD.TYPE_INT64
1588                    field.default_value = str(defvals[0])
1589                    field.ClearField(str('type_name'))
1590                else:
1591                    optional_only.field.remove(field)
1592            elif not field.HasField('default_value'):
1593                optional_only.field.remove(field)
1594
1595        if len(optional_only.field) == 0:
1596            return b''
1597
1598        optional_only.ClearField(str('oneof_decl'))
1599        optional_only.ClearField(str('nested_type'))
1600        optional_only.ClearField(str('extension'))
1601        optional_only.ClearField(str('enum_type'))
1602        optional_only.name += str(id(self))
1603
1604        desc = google.protobuf.descriptor.MakeDescriptor(optional_only)
1605        msg = reflection.MakeClass(desc)()
1606
1607        for field in optional_only.field:
1608            if field.type == FieldD.TYPE_STRING:
1609                setattr(msg, field.name, field.default_value)
1610            elif field.type == FieldD.TYPE_BYTES:
1611                setattr(msg, field.name, codecs.escape_decode(field.default_value)[0])
1612            elif field.type in [FieldD.TYPE_FLOAT, FieldD.TYPE_DOUBLE]:
1613                setattr(msg, field.name, float(field.default_value))
1614            elif field.type == FieldD.TYPE_BOOL:
1615                setattr(msg, field.name, field.default_value == 'true')
1616            else:
1617                setattr(msg, field.name, int(field.default_value))
1618
1619        return msg.SerializeToString()
1620
1621
1622# ---------------------------------------------------------------------------
1623#                    Processing of entire .proto files
1624# ---------------------------------------------------------------------------
1625
1626def iterate_messages(desc, flatten = False, names = Names(), comment_path = ()):
1627    '''Recursively find all messages. For each, yield name, DescriptorProto, comment_path.'''
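    # For example, a top-level message "Outer" containing a nested message
    # "Inner" yields the names "Outer" and "Outer_Inner"; with flatten=True the
    # nested message is yielded as just "Inner".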
1628    if hasattr(desc, 'message_type'):
1629        submsgs = desc.message_type
1630        comment_path += (ProtoElement.MESSAGE,)
1631    else:
1632        submsgs = desc.nested_type
1633        comment_path += (ProtoElement.NESTED_TYPE,)
1634
1635    for idx, submsg in enumerate(submsgs):
1636        sub_names = names + submsg.name
1637        sub_path = comment_path + (idx,)
1638        if flatten:
1639            yield Names(submsg.name), submsg, sub_path
1640        else:
1641            yield sub_names, submsg, sub_path
1642
1643        for x in iterate_messages(submsg, flatten, sub_names, sub_path):
1644            yield x
1645
1646def iterate_extensions(desc, flatten = False, names = Names()):
1647    '''Recursively find all extensions.
1648    For each, yield name, FieldDescriptorProto.
1649    '''
1650    for extension in desc.extension:
1651        yield names, extension
1652
1653    for subname, subdesc, comment_path in iterate_messages(desc, flatten, names):
1654        for extension in subdesc.extension:
1655            yield subname, extension
1656
1657def sort_dependencies(messages):
1658    '''Sort a list of Messages based on dependencies.'''
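    # For example, if message A has a statically allocated field of submessage
    # type B, then B is yielded before A so that the generated C struct for B
    # is defined before it is embedded in A.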
1659
1660    # Construct first level list of dependencies
1661    dependencies = {}
1662    for message in messages:
1663        dependencies[str(message.name)] = set(message.get_dependencies())
1664
1665    # Emit messages after all their dependencies have been processed
1666    remaining = list(messages)
1667    remainset = set(str(m.name) for m in remaining)
1668    while remaining:
1669        for candidate in remaining:
1670            if not remainset.intersection(dependencies[str(candidate.name)]):
1671                remaining.remove(candidate)
1672                remainset.remove(str(candidate.name))
1673                yield candidate
1674                break
1675        else:
1676            sys.stderr.write("Circular dependency in messages: " + ', '.join(remainset) + " (consider changing to FT_POINTER or FT_CALLBACK)\n")
1677            candidate = remaining.pop(0)
1678            remainset.remove(str(candidate.name))
1679            yield candidate
1680
1681def make_identifier(headername):
1682    '''Make #ifndef identifier that contains uppercase A-Z and digits 0-9'''
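    # e.g. make_identifier('sub/dir/my-file.pb.h') -> 'SUB_DIR_MY_FILE_PB_H'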
1683    result = ""
1684    for c in headername.upper():
1685        if c.isalnum():
1686            result += c
1687        else:
1688            result += '_'
1689    return result
1690
1691class MangleNames:
1692    '''Handles conversion of type names according to mangle_names option:
1693    M_NONE = 0; // Default, no typename mangling
1694    M_STRIP_PACKAGE = 1; // Strip current package name
1695    M_FLATTEN = 2; // Only use last path component
1696    M_PACKAGE_INITIALS = 3; // Replace the package name by the initials
1697    '''
1698    def __init__(self, fdesc, file_options):
1699        self.file_options = file_options
1700        self.mangle_names = file_options.mangle_names
1701        self.flatten = (self.mangle_names == nanopb_pb2.M_FLATTEN)
1702        self.strip_prefix = None
1703        self.replacement_prefix = None
1704        self.name_mapping = {}
1705        self.reverse_name_mapping = {}
1706        self.canonical_base = Names(fdesc.package.split('.'))
1707
1708        if self.mangle_names == nanopb_pb2.M_STRIP_PACKAGE:
1709            self.strip_prefix = "." + fdesc.package
1710        elif self.mangle_names == nanopb_pb2.M_PACKAGE_INITIALS:
1711            self.strip_prefix = "." + fdesc.package
1712            self.replacement_prefix = ""
1713            for part in fdesc.package.split("."):
1714                self.replacement_prefix += part[0]
1715        elif file_options.package:
1716            self.strip_prefix = "." + fdesc.package
1717            self.replacement_prefix = file_options.package
1718
1719        if self.strip_prefix == '.':
1720            self.strip_prefix = ''
1721
1722        if self.replacement_prefix is not None:
1723            self.base_name = Names(self.replacement_prefix.split('.'))
1724        elif fdesc.package:
1725            self.base_name = Names(fdesc.package.split('.'))
1726        else:
1727            self.base_name = Names()
1728
1729    def create_name(self, names):
1730        '''Create name for a new message / enum.
1731        Argument can be either string or Names instance.
1732        '''
1733        if str(names) not in self.name_mapping:
1734            if self.mangle_names in (nanopb_pb2.M_NONE, nanopb_pb2.M_PACKAGE_INITIALS):
1735                new_name = self.base_name + names
1736            elif self.mangle_names == nanopb_pb2.M_STRIP_PACKAGE:
1737                new_name = Names(names)
1738            elif isinstance(names, Names):
1739                new_name = Names(names.parts[-1])
1740            else:
1741                new_name = Names(names)
1742
1743            if str(new_name) in self.reverse_name_mapping:
1744                sys.stderr.write("Warning: Duplicate name with mangle_names=%s: %s and %s map to %s\n" %
1745                    (self.mangle_names, self.reverse_name_mapping[str(new_name)], names, new_name))
1746
1747            self.name_mapping[str(names)] = new_name
1748            self.reverse_name_mapping[str(new_name)] = self.canonical_base + names
1749
1750        return self.name_mapping[str(names)]
1751
1752    def mangle_field_typename(self, typename):
1753        '''Mangle type name for a submessage / enum crossreference.
1754        Argument is a string.
1755        '''
1756        if self.mangle_names == nanopb_pb2.M_FLATTEN:
1757            return "." + typename.split(".")[-1]
1758
1759        if self.strip_prefix is not None and typename.startswith(self.strip_prefix):
1760            if self.replacement_prefix is not None:
1761                return "." + self.replacement_prefix + typename[len(self.strip_prefix):]
1762            else:
1763                return typename[len(self.strip_prefix):]
1764
1765        if self.file_options.package:
1766            return "." + self.replacement_prefix + typename
1767
1768        return typename
1769
1770    def unmangle(self, names):
1771        return self.reverse_name_mapping.get(str(names), names)
1772
1773class ProtoFile:
1774    def __init__(self, fdesc, file_options):
1775        '''Takes a FileDescriptorProto and parses it.'''
1776        self.fdesc = fdesc
1777        self.file_options = file_options
1778        self.dependencies = {}
1779        self.math_include_required = False
1780        self.parse()
1781        for message in self.messages:
1782            if message.math_include_required:
1783                self.math_include_required = True
1784                break
1785
        # Some of the types used in this file probably come from the file itself.
        # Thus it has an implicit dependency on itself.
1788        self.add_dependency(self)
1789
1790    def parse(self):
1791        self.enums = []
1792        self.messages = []
1793        self.extensions = []
1794        self.manglenames = MangleNames(self.fdesc, self.file_options)
1795
        # Process source code comment locations.
        # Locations that do not contain any comment information are ignored.
1798        self.comment_locations = {
1799            tuple(location.path): location
1800            for location in self.fdesc.source_code_info.location
1801            if location.leading_comments or location.leading_detached_comments or location.trailing_comments
1802        }
1803
1804        for index, enum in enumerate(self.fdesc.enum_type):
1805            name = self.manglenames.create_name(enum.name)
1806            enum_options = get_nanopb_suboptions(enum, self.file_options, name)
1807            enum_path = (ProtoElement.ENUM, index)
1808            self.enums.append(Enum(name, enum, enum_options, enum_path, self.comment_locations))
1809
1810        for names, message, comment_path in iterate_messages(self.fdesc, self.manglenames.flatten):
1811            name = self.manglenames.create_name(names)
1812            message_options = get_nanopb_suboptions(message, self.file_options, name)
1813
1814            if message_options.skip_message:
1815                continue
1816
1817            message = copy.deepcopy(message)
1818            for field in message.field:
1819                if field.type in (FieldD.TYPE_MESSAGE, FieldD.TYPE_ENUM):
1820                    field.type_name = self.manglenames.mangle_field_typename(field.type_name)
1821
1822            self.messages.append(Message(name, message, message_options, comment_path, self.comment_locations))
1823            for index, enum in enumerate(message.enum_type):
1824                name = self.manglenames.create_name(names + enum.name)
1825                enum_options = get_nanopb_suboptions(enum, message_options, name)
1826                enum_path = comment_path + (ProtoElement.NESTED_ENUM, index)
1827                self.enums.append(Enum(name, enum, enum_options, enum_path, self.comment_locations))
1828
1829        for names, extension in iterate_extensions(self.fdesc, self.manglenames.flatten):
1830            name = self.manglenames.create_name(names + extension.name)
1831            field_options = get_nanopb_suboptions(extension, self.file_options, name)
1832
1833            extension = copy.deepcopy(extension)
1834            if extension.type in (FieldD.TYPE_MESSAGE, FieldD.TYPE_ENUM):
1835                extension.type_name = self.manglenames.mangle_field_typename(extension.type_name)
1836
1837            if field_options.type != nanopb_pb2.FT_IGNORE:
1838                self.extensions.append(ExtensionField(name, extension, field_options))
1839
1840    def add_dependency(self, other):
1841        for enum in other.enums:
1842            self.dependencies[str(enum.names)] = enum
1843            self.dependencies[str(other.manglenames.unmangle(enum.names))] = enum
1844            enum.protofile = other
1845
1846        for msg in other.messages:
1847            self.dependencies[str(msg.name)] = msg
1848            self.dependencies[str(other.manglenames.unmangle(msg.name))] = msg
1849            msg.protofile = other
1850
1851        # Fix field default values where enum short names are used.
1852        for enum in other.enums:
1853            if not enum.options.long_names:
1854                for message in self.messages:
1855                    for field in message.all_fields():
1856                        if field.default in enum.value_longnames:
1857                            idx = enum.value_longnames.index(field.default)
1858                            field.default = enum.values[idx][0]
1859
        # Use the unsigned datatype (UENUM) for enum fields whose enum type
        # has no negative values.
1861        for enum in other.enums:
1862            if not enum.has_negative():
1863                for message in self.messages:
1864                    for field in message.all_fields():
1865                        if field.pbtype == 'ENUM' and field.ctype == enum.names:
1866                            field.pbtype = 'UENUM'
1867
1868    def generate_header(self, includes, headername, options):
1869        '''Generate content for a header file.
1870        Generates strings, which should be concatenated and stored to file.
1871        '''
1872
1873        yield '/* Automatically generated nanopb header */\n'
1874        if options.notimestamp:
1875            yield '/* Generated by %s */\n\n' % (nanopb_version)
1876        else:
1877            yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime())
1878
1879        if self.fdesc.package:
1880            symbol = make_identifier(self.fdesc.package + '_' + headername)
1881        else:
1882            symbol = make_identifier(headername)
1883        yield '#ifndef PB_%s_INCLUDED\n' % symbol
1884        yield '#define PB_%s_INCLUDED\n' % symbol
1885        if self.math_include_required:
1886            yield '#include <math.h>\n'
1887        try:
1888            yield options.libformat % ('pb.h')
1889        except TypeError:
1890            # no %s specified - use whatever was passed in as options.libformat
1891            yield options.libformat
1892        yield '\n'
1893
1894        for incfile in self.file_options.include:
1895            # allow including system headers
1896            if (incfile.startswith('<')):
1897                yield '#include %s\n' % incfile
1898            else:
1899                yield options.genformat % incfile
1900                yield '\n'
1901
1902        for incfile in includes:
1903            noext = os.path.splitext(incfile)[0]
1904            yield options.genformat % (noext + options.extension + options.header_extension)
1905            yield '\n'
1906
1907        if Globals.protoc_insertion_points:
1908            yield '/* @@protoc_insertion_point(includes) */\n'
1909
1910        yield '\n'
1911
1912        yield '#if PB_PROTO_HEADER_VERSION != 40\n'
1913        yield '#error Regenerate this file with the current version of nanopb generator.\n'
1914        yield '#endif\n'
1915        yield '\n'
1916
1917        if self.enums:
1918            yield '/* Enum definitions */\n'
1919            for enum in self.enums:
1920                yield str(enum) + '\n\n'
1921
1922        if self.messages:
1923            yield '/* Struct definitions */\n'
1924            for msg in sort_dependencies(self.messages):
1925                yield msg.types()
1926                yield str(msg) + '\n'
1927            yield '\n'
1928
1929        if self.extensions:
1930            yield '/* Extensions */\n'
1931            for extension in self.extensions:
1932                yield extension.extension_decl()
1933            yield '\n'
1934
1935        yield '#ifdef __cplusplus\n'
1936        yield 'extern "C" {\n'
1937        yield '#endif\n\n'
1938
1939        if self.enums:
            yield '/* Helper constants for enums */\n'
            for enum in self.enums:
                yield enum.auxiliary_defines() + '\n'

            for msg in self.messages:
                yield msg.enumtype_defines() + '\n'
            yield '\n'
1947
1948        if self.messages:
1949            yield '/* Initializer values for message structs */\n'
1950            for msg in self.messages:
1951                identifier = Globals.naming_style.define_name('%s_init_default' % msg.name)
1952                yield '#define %-40s %s\n' % (identifier, msg.get_initializer(False))
1953            for msg in self.messages:
1954                identifier = Globals.naming_style.define_name('%s_init_zero' % msg.name)
1955                yield '#define %-40s %s\n' % (identifier, msg.get_initializer(True))
1956            yield '\n'
1957
1958            yield '/* Field tags (for use in manual encoding/decoding) */\n'
1959            for msg in sort_dependencies(self.messages):
1960                for field in msg.fields:
1961                    yield field.tags()
1962            for extension in self.extensions:
1963                yield extension.tags()
1964            yield '\n'
1965
1966            yield '/* Struct field encoding specification for nanopb */\n'
1967            for msg in self.messages:
1968                yield msg.fields_declaration(self.dependencies) + '\n'
1969            for msg in self.messages:
1970                yield 'extern const pb_msgdesc_t %s_msg;\n' % Globals.naming_style.type_name(msg.name)
1971            yield '\n'
1972
1973            yield '/* Defines for backwards compatibility with code written before nanopb-0.4.0 */\n'
1974            for msg in self.messages:
                yield '#define %s &%s_msg\n' % (
                    Globals.naming_style.define_name('%s_fields' % msg.name),
                    Globals.naming_style.type_name(msg.name))
1978            yield '\n'
1979
1980            yield '/* Maximum encoded size of messages (where known) */\n'
1981            messagesizes = []
1982            for msg in self.messages:
1983                identifier = '%s_size' % msg.name
1984                messagesizes.append((identifier, msg.encoded_size(self.dependencies)))
1985
1986            # If we require a symbol from another file, put a preprocessor if statement
1987            # around it to prevent compilation errors if the symbol is not actually available.
1988            local_defines = [identifier for identifier, msize in messagesizes if msize is not None]
1989
1990            # emit size_unions, if any
1991            oneof_sizes = []
1992            for msg in self.messages:
1993                for f in msg.fields:
1994                    if isinstance(f, OneOf):
1995                        msize = f.encoded_size(self.dependencies)
1996                        if msize is not None:
1997                            oneof_sizes.append(msize)
1998            for msize in oneof_sizes:
1999                guard = msize.get_cpp_guard(local_defines)
2000                if guard:
2001                    yield guard
2002                yield msize.get_declarations()
2003                if guard:
2004                    yield '#endif\n'
2005
2006            guards = {}
2007            for identifier, msize in messagesizes:
2008                if msize is not None:
2009                    cpp_guard = msize.get_cpp_guard(local_defines)
2010                    if cpp_guard not in guards:
2011                        guards[cpp_guard] = set()
2012                    guards[cpp_guard].add('#define %-40s %s' % (
2013                        Globals.naming_style.define_name(identifier), msize))
2014                else:
2015                    yield '/* %s depends on runtime parameters */\n' % identifier
2016            for guard, values in guards.items():
2017                if guard:
2018                    yield guard
2019                for v in sorted(values):
2020                    yield v
2021                    yield '\n'
2022                if guard:
2023                    yield '#endif\n'
2024            yield '\n'
2025
            if [msg for msg in self.messages if hasattr(msg, 'msgid')]:
                yield '/* Message IDs (where set with "msgid" option) */\n'
                for msg in self.messages:
                    if hasattr(msg, 'msgid'):
                        yield '#define PB_MSG_%d %s\n' % (msg.msgid, msg.name)
                yield '\n'

                symbol = make_identifier(headername.split('.')[0])
                yield '#define %s_MESSAGES \\\n' % symbol

                for msg in self.messages:
                    m = "-1"
                    msize = msg.encoded_size(self.dependencies)
                    if msize is not None:
                        m = msize
                    if hasattr(msg, 'msgid'):
                        yield '\tPB_MSG(%d,%s,%s) \\\n' % (msg.msgid, m, msg.name)
                yield '\n'

                for msg in self.messages:
                    if hasattr(msg, 'msgid'):
                        yield '#define %s_msgid %d\n' % (msg.name, msg.msgid)
                yield '\n'
2049
2050        # Check if there is any name mangling active
2051        pairs = [x for x in self.manglenames.reverse_name_mapping.items() if str(x[0]) != str(x[1])]
2052        if pairs:
2053            yield '/* Mapping from canonical names (mangle_names or overridden package name) */\n'
2054            for shortname, longname in pairs:
2055                yield '#define %s %s\n' % (longname, shortname)
2056            yield '\n'
2057
2058        yield '#ifdef __cplusplus\n'
2059        yield '} /* extern "C" */\n'
2060        yield '#endif\n'
2061
2062        if options.cpp_descriptors:
2063            yield '\n'
2064            yield '#ifdef __cplusplus\n'
2065            yield '/* Message descriptors for nanopb */\n'
2066            yield 'namespace nanopb {\n'
2067            for msg in self.messages:
2068                yield msg.fields_declaration_cpp_lookup() + '\n'
2069            yield '}  // namespace nanopb\n'
2070            yield '\n'
2071            yield '#endif  /* __cplusplus */\n'
2072            yield '\n'
2073
2074        if Globals.protoc_insertion_points:
2075            yield '/* @@protoc_insertion_point(eof) */\n'
2076
2077        # End of header
2078        yield '\n#endif\n'
2079
2080    def generate_source(self, headername, options):
2081        '''Generate content for a source file.'''
2082
2083        yield '/* Automatically generated nanopb constant definitions */\n'
2084        if options.notimestamp:
2085            yield '/* Generated by %s */\n\n' % (nanopb_version)
2086        else:
2087            yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime())
2088        yield options.genformat % (headername)
2089        yield '\n'
2090
2091        if Globals.protoc_insertion_points:
2092            yield '/* @@protoc_insertion_point(includes) */\n'
2093
2094        yield '#if PB_PROTO_HEADER_VERSION != 40\n'
2095        yield '#error Regenerate this file with the current version of nanopb generator.\n'
2096        yield '#endif\n'
2097        yield '\n'
2098
2099        # Check if any messages exceed the 64 kB limit of 16-bit pb_size_t
2100        exceeds_64kB = []
2101        for msg in self.messages:
2102            size = msg.data_size(self.dependencies)
2103            if size >= 65536:
2104                exceeds_64kB.append(str(msg.name))
2105
2106        if exceeds_64kB:
2107            yield '\n/* The following messages exceed 64kB in size: ' + ', '.join(exceeds_64kB) + ' */\n'
2108            yield '\n/* The PB_FIELD_32BIT compilation option must be defined to support messages that exceed 64 kB in size. */\n'
2109            yield '#ifndef PB_FIELD_32BIT\n'
2110            yield '#error Enable PB_FIELD_32BIT to support messages exceeding 64kB in size: ' + ', '.join(exceeds_64kB) + '\n'
2111            yield '#endif\n'
2112
2113        # Generate the message field definitions (PB_BIND() call)
2114        for msg in self.messages:
2115            yield msg.fields_definition(self.dependencies) + '\n\n'
2116
2117        # Generate pb_extension_type_t definitions if extensions are used in proto file
2118        for ext in self.extensions:
2119            yield ext.extension_def(self.dependencies) + '\n'
2120
2121        # Generate enum_name function if enum_to_string option is defined
2122        for enum in self.enums:
2123            yield enum.enum_to_string_definition() + '\n'
2124
2125        # Add checks for numeric limits
2126        if self.messages:
2127            largest_msg = max(self.messages, key = lambda m: m.count_required_fields())
2128            largest_count = largest_msg.count_required_fields()
2129            if largest_count > 64:
2130                yield '\n/* Check that missing required fields will be properly detected */\n'
2131                yield '#if PB_MAX_REQUIRED_FIELDS < %d\n' % largest_count
2132                yield '#error Properly detecting missing required fields in %s requires \\\n' % largest_msg.name
2133                yield '       setting PB_MAX_REQUIRED_FIELDS to %d or more.\n' % largest_count
2134                yield '#endif\n'
2135
2136        # Add check for sizeof(double)
2137        has_double = False
2138        for msg in self.messages:
2139            for field in msg.all_fields():
2140                if field.ctype == 'double':
2141                    has_double = True
2142
2143        if has_double:
2144            yield '\n'
2145            yield '#ifndef PB_CONVERT_DOUBLE_FLOAT\n'
2146            yield '/* On some platforms (such as AVR), double is really float.\n'
            yield ' * To be able to encode/decode double on these platforms, you need\n'
2148            yield ' * to define PB_CONVERT_DOUBLE_FLOAT in pb.h or compiler command line.\n'
2149            yield ' */\n'
2150            yield 'PB_STATIC_ASSERT(sizeof(double) == 8, DOUBLE_MUST_BE_8_BYTES)\n'
2151            yield '#endif\n'
2152
2153        yield '\n'
2154
2155        if Globals.protoc_insertion_points:
2156            yield '/* @@protoc_insertion_point(eof) */\n'
2157
2158# ---------------------------------------------------------------------------
2159#                    Options parsing for the .proto files
2160# ---------------------------------------------------------------------------
2161
2162from fnmatch import fnmatchcase
2163
2164def read_options_file(infile):
2165    '''Parse a separate options file to list:
2166        [(namemask, options), ...]
2167    '''
2168    results = []
2169    data = infile.read()
2170    data = re.sub(r'/\*.*?\*/', '', data, flags = re.MULTILINE)
2171    data = re.sub(r'//.*?$', '', data, flags = re.MULTILINE)
2172    data = re.sub(r'#.*?$', '', data, flags = re.MULTILINE)
2173    for i, line in enumerate(data.split('\n')):
2174        line = line.strip()
2175        if not line:
2176            continue
2177
2178        parts = line.split(None, 1)
2179
2180        if len(parts) < 2:
2181            sys.stderr.write("%s:%d: " % (infile.name, i + 1) +
2182                             "Option lines should have space between field name and options. " +
2183                             "Skipping line: '%s'\n" % line)
2184            sys.exit(1)
2185
2186        opts = nanopb_pb2.NanoPBOptions()
2187
2188        try:
2189            text_format.Merge(parts[1], opts)
2190        except Exception as e:
2191            sys.stderr.write("%s:%d: " % (infile.name, i + 1) +
2192                             "Unparsable option line: '%s'. " % line +
2193                             "Error: %s\n" % str(e))
2194            sys.exit(1)
2195        results.append((parts[0], opts))
2196
2197    return results
2198
2199def get_nanopb_suboptions(subdesc, options, name):
2200    '''Get copy of options, and merge information from subdesc.'''
2201    new_options = nanopb_pb2.NanoPBOptions()
2202    new_options.CopyFrom(options)
2203
2204    if hasattr(subdesc, 'syntax') and subdesc.syntax == "proto3":
2205        new_options.proto3 = True
2206
2207    # Handle options defined in a separate file
2208    dotname = '.'.join(name.parts)
    for namemask, mask_options in Globals.separate_options:
        if fnmatchcase(dotname, namemask):
            Globals.matched_namemasks.add(namemask)
            new_options.MergeFrom(mask_options)
2213
2214    # Handle options defined in .proto
2215    if isinstance(subdesc.options, descriptor.FieldOptions):
2216        ext_type = nanopb_pb2.nanopb
2217    elif isinstance(subdesc.options, descriptor.FileOptions):
2218        ext_type = nanopb_pb2.nanopb_fileopt
2219    elif isinstance(subdesc.options, descriptor.MessageOptions):
2220        ext_type = nanopb_pb2.nanopb_msgopt
2221    elif isinstance(subdesc.options, descriptor.EnumOptions):
2222        ext_type = nanopb_pb2.nanopb_enumopt
2223    else:
2224        raise Exception("Unknown options type")
2225
2226    if subdesc.options.HasExtension(ext_type):
2227        ext = subdesc.options.Extensions[ext_type]
2228        new_options.MergeFrom(ext)
2229
2230    if Globals.verbose_options:
2231        sys.stderr.write("Options for " + dotname + ": ")
2232        sys.stderr.write(text_format.MessageToString(new_options) + "\n")
2233
2234    return new_options
2235
2236
2237# ---------------------------------------------------------------------------
2238#                         Command line interface
2239# ---------------------------------------------------------------------------
2240
2241import sys
2242import os.path
2243from optparse import OptionParser
2244
2245optparser = OptionParser(
2246    usage = "Usage: nanopb_generator.py [options] file.pb ...",
2247    epilog = "Compile file.pb from file.proto by: 'protoc -ofile.pb file.proto'. " +
2248             "Output will be written to file.pb.h and file.pb.c.")
2249optparser.add_option("--version", dest="version", action="store_true",
2250    help="Show version info and exit")
2251optparser.add_option("-x", dest="exclude", metavar="FILE", action="append", default=[],
2252    help="Exclude file from generated #include list.")
2253optparser.add_option("-e", "--extension", dest="extension", metavar="EXTENSION", default=".pb",
2254    help="Set extension to use instead of '.pb' for generated files. [default: %default]")
2255optparser.add_option("-H", "--header-extension", dest="header_extension", metavar="EXTENSION", default=".h",
2256    help="Set extension to use for generated header files. [default: %default]")
2257optparser.add_option("-S", "--source-extension", dest="source_extension", metavar="EXTENSION", default=".c",
2258    help="Set extension to use for generated source files. [default: %default]")
2259optparser.add_option("-f", "--options-file", dest="options_file", metavar="FILE", default="%s.options",
2260    help="Set name of a separate generator options file.")
2261optparser.add_option("-I", "--options-path", "--proto-path", dest="options_path", metavar="DIR",
2262    action="append", default = [],
2263    help="Search path for .options and .proto files. Also determines relative paths for output directory structure.")
2264optparser.add_option("--error-on-unmatched", dest="error_on_unmatched", action="store_true", default=False,
2265                     help ="Stop generation if there are unmatched fields in options file")
2266optparser.add_option("--no-error-on-unmatched", dest="error_on_unmatched", action="store_false", default=False,
2267                     help ="Continue generation if there are unmatched fields in options file (default)")
2268optparser.add_option("-D", "--output-dir", dest="output_dir",
2269                     metavar="OUTPUTDIR", default=None,
2270                     help="Output directory of .pb.h and .pb.c files")
2271optparser.add_option("-Q", "--generated-include-format", dest="genformat",
2272    metavar="FORMAT", default='#include "%s"',
2273    help="Set format string to use for including other .pb.h files. Value can be 'quote', 'bracket' or a format string. [default: %default]")
2274optparser.add_option("-L", "--library-include-format", dest="libformat",
2275    metavar="FORMAT", default='#include <%s>',
2276    help="Set format string to use for including the nanopb pb.h header. Value can be 'quote', 'bracket' or a format string. [default: %default]")
2277optparser.add_option("--strip-path", dest="strip_path", action="store_true", default=False,
2278    help="Strip directory path from #included .pb.h file name")
2279optparser.add_option("--no-strip-path", dest="strip_path", action="store_false",
2280    help="Opposite of --strip-path (default since 0.4.0)")
2281optparser.add_option("--cpp-descriptors", action="store_true",
2282    help="Generate C++ descriptors to lookup by type (e.g. pb_field_t for a message)")
2283optparser.add_option("-T", "--no-timestamp", dest="notimestamp", action="store_true", default=True,
2284    help="Don't add timestamp to .pb.h and .pb.c preambles (default since 0.4.0)")
2285optparser.add_option("-t", "--timestamp", dest="notimestamp", action="store_false", default=True,
2286    help="Add timestamp to .pb.h and .pb.c preambles")
2287optparser.add_option("-q", "--quiet", dest="quiet", action="store_true", default=False,
2288    help="Don't print anything except errors.")
2289optparser.add_option("-v", "--verbose", dest="verbose", action="store_true", default=False,
2290    help="Print more information.")
2291optparser.add_option("-s", dest="settings", metavar="OPTION:VALUE", action="append", default=[],
2292    help="Set generator option (max_size, max_count etc.).")
2293optparser.add_option("--protoc-opt", dest="protoc_opts", action="append", default = [], metavar="OPTION",
2294    help="Pass an option to protoc when compiling .proto files")
2295optparser.add_option("--protoc-insertion-points", dest="protoc_insertion_points", action="store_true", default=False,
2296    help="Include insertion point comments in output for use by custom protoc plugins")
2297optparser.add_option("-C", "--c-style", dest="c_style", action="store_true", default=False,
2298    help="Use C naming convention.")
2299
2300def process_cmdline(args, is_plugin):
2301    '''Process command line options. Returns list of options, filenames.'''
2302
2303    options, filenames = optparser.parse_args(args)
2304
2305    if options.version:
2306        if is_plugin:
2307            sys.stderr.write('%s\n' % (nanopb_version))
2308        else:
2309            print(nanopb_version)
2310        sys.exit(0)
2311
2312    if not filenames and not is_plugin:
2313        optparser.print_help()
2314        sys.exit(1)
2315
2316    if options.quiet:
2317        options.verbose = False
2318
2319    include_formats = {'quote': '#include "%s"', 'bracket': '#include <%s>'}
2320    options.libformat = include_formats.get(options.libformat, options.libformat)
2321    options.genformat = include_formats.get(options.genformat, options.genformat)
2322
2323    if options.c_style:
2324        Globals.naming_style = NamingStyleC()
2325
2326    Globals.verbose_options = options.verbose
2327
2328    if options.verbose:
2329        sys.stderr.write("Nanopb version %s\n" % nanopb_version)
2330        sys.stderr.write('Google Python protobuf library imported from %s, version %s\n'
2331                         % (google.protobuf.__file__, google.protobuf.__version__))
2332
2333    return options, filenames
2334
2335
2336def parse_file(filename, fdesc, options):
2337    '''Parse a single file. Returns a ProtoFile instance.'''
2338    toplevel_options = nanopb_pb2.NanoPBOptions()
2339    for s in options.settings:
2340        if ':' not in s and '=' in s:
2341            s = s.replace('=', ':')
2342        text_format.Merge(s, toplevel_options)
2343
2344    if not fdesc:
2345        data = open(filename, 'rb').read()
2346        fdesc = descriptor.FileDescriptorSet.FromString(data).file[0]
2347
2348    # Check if there is a separate .options file
2349    had_abspath = False
2350    try:
2351        optfilename = options.options_file % os.path.splitext(filename)[0]
2352    except TypeError:
2353        # No %s specified, use the filename as-is
2354        optfilename = options.options_file
2355        had_abspath = True
2356
2357    paths = ['.'] + options.options_path
2358    for p in paths:
2359        if os.path.isfile(os.path.join(p, optfilename)):
2360            optfilename = os.path.join(p, optfilename)
2361            if options.verbose:
2362                sys.stderr.write('Reading options from ' + optfilename + '\n')
2363            Globals.separate_options = read_options_file(open(optfilename, 'r', encoding = 'utf-8'))
2364            break
2365    else:
2366        # If we are given a full filename and it does not exist, give an error.
        # However, don't give an error when we automatically look for an .options
        # file with the same name as the .proto file.
2369        if options.verbose or had_abspath:
2370            sys.stderr.write('Options file not found: ' + optfilename + '\n')
2371        Globals.separate_options = []
2372
2373    Globals.matched_namemasks = set()
2374    Globals.protoc_insertion_points = options.protoc_insertion_points
2375
2376    # Parse the file
2377    file_options = get_nanopb_suboptions(fdesc, toplevel_options, Names([filename]))
2378    f = ProtoFile(fdesc, file_options)
2379    f.optfilename = optfilename
2380
2381    return f
2382
2383def process_file(filename, fdesc, options, other_files = {}):
2384    '''Process a single file.
2385    filename: The full path to the .proto or .pb source file, as string.
2386    fdesc: The loaded FileDescriptorSet, or None to read from the input file.
2387    options: Command line options as they come from OptionsParser.
2388
2389    Returns a dict:
2390        {'headername': Name of header file,
2391         'headerdata': Data for the .h header file,
2392         'sourcename': Name of the source code file,
2393         'sourcedata': Data for the .c source code file
2394        }
2395    '''
2396    f = parse_file(filename, fdesc, options)
2397
2398    # Check the list of dependencies, and if they are available in other_files,
2399    # add them to be considered for import resolving. Recursively add any files
2400    # imported by the dependencies.
2401    deps = list(f.fdesc.dependency)
2402    while deps:
2403        dep = deps.pop(0)
2404        if dep in other_files:
2405            f.add_dependency(other_files[dep])
2406            deps += list(other_files[dep].fdesc.dependency)
2407
2408    # Decide the file names
2409    noext = os.path.splitext(filename)[0]
2410    headername = noext + options.extension + options.header_extension
2411    sourcename = noext + options.extension + options.source_extension
2412
2413    if options.strip_path:
2414        headerbasename = os.path.basename(headername)
2415    else:
2416        headerbasename = headername
2417
2418    # List of .proto files that should not be included in the C header file
2419    # even if they are mentioned in the source .proto.
2420    excludes = ['nanopb.proto', 'google/protobuf/descriptor.proto'] + options.exclude + list(f.file_options.exclude)
2421    includes = [d for d in f.fdesc.dependency if d not in excludes]
2422
2423    headerdata = ''.join(f.generate_header(includes, headerbasename, options))
2424    sourcedata = ''.join(f.generate_source(headerbasename, options))
2425
2426    # Check if there were any lines in .options that did not match a member
2427    unmatched = [n for n,o in Globals.separate_options if n not in Globals.matched_namemasks]
2428    if unmatched:
2429        if options.error_on_unmatched:
2430            raise Exception("Following patterns in " + f.optfilename + " did not match any fields: "
                            + ', '.join(unmatched))
2432        elif not options.quiet:
2433            sys.stderr.write("Following patterns in " + f.optfilename + " did not match any fields: "
2434                            + ', '.join(unmatched) + "\n")
2435
2436            if not Globals.verbose_options:
2437                sys.stderr.write("Use  protoc --nanopb-out=-v:.   to see a list of the field names.\n")
2438
2439    return {'headername': headername, 'headerdata': headerdata,
2440            'sourcename': sourcename, 'sourcedata': sourcedata}
2441
2442def main_cli():
2443    '''Main function when invoked directly from the command line.'''
2444
2445    options, filenames = process_cmdline(sys.argv[1:], is_plugin = False)
2446
2447    if options.output_dir and not os.path.exists(options.output_dir):
2448        optparser.print_help()
2449        sys.stderr.write("\noutput_dir does not exist: %s\n" % options.output_dir)
2450        sys.exit(1)
2451
2452    # Load .pb files into memory and compile any .proto files.
2453    include_path = ['-I%s' % p for p in options.options_path]
2454    all_fdescs = {}
2455    out_fdescs = {}
2456    for filename in filenames:
2457        if filename.endswith(".proto"):
2458            with TemporaryDirectory() as tmpdir:
2459                tmpname = os.path.join(tmpdir, os.path.basename(filename) + ".pb")
2460                args = ["protoc"] + include_path
2461                args += options.protoc_opts
2462                args += ['--include_imports', '--include_source_info', '-o' + tmpname, filename]
2463                status = invoke_protoc(args)
2464                if status != 0: sys.exit(status)
2465                data = open(tmpname, 'rb').read()
2466        else:
2467            data = open(filename, 'rb').read()
2468
2469        fdescs = descriptor.FileDescriptorSet.FromString(data).file
2470        last_fdesc = fdescs[-1]
2471
2472        for fdesc in fdescs:
            all_fdescs[fdesc.name] = fdesc
2474
2475        out_fdescs[last_fdesc.name] = last_fdesc
2476
2477    # Process any include files first, in order to have them
2478    # available as dependencies
2479    other_files = {}
2480    for fdesc in all_fdescs.values():
2481        other_files[fdesc.name] = parse_file(fdesc.name, fdesc, options)
2482
2483    # Then generate the headers / sources
2484    for fdesc in out_fdescs.values():
2485        results = process_file(fdesc.name, fdesc, options, other_files)
2486
2487        base_dir = options.output_dir or ''
2488        to_write = [
2489            (os.path.join(base_dir, results['headername']), results['headerdata']),
2490            (os.path.join(base_dir, results['sourcename']), results['sourcedata']),
2491        ]
2492
2493        if not options.quiet:
2494            paths = " and ".join([x[0] for x in to_write])
2495            sys.stderr.write("Writing to %s\n" % paths)
2496
2497        for path, data in to_write:
2498            dirname = os.path.dirname(path)
2499            if dirname and not os.path.exists(dirname):
2500                os.makedirs(dirname)
2501
2502            with open(path, 'w') as f:
2503                f.write(data)
2504
2505def main_plugin():
2506    '''Main function when invoked as a protoc plugin.'''
2507
2508    import io, sys
2509    if sys.platform == "win32":
2510        import os, msvcrt
2511        # Set stdin and stdout to binary mode
2512        msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)
2513        msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
2514
2515    data = io.open(sys.stdin.fileno(), "rb").read()
2516
2517    request = plugin_pb2.CodeGeneratorRequest.FromString(data)
2518
2519    try:
2520        # Versions of Python prior to 2.7.3 do not support unicode
2521        # input to shlex.split(). Try to convert to str if possible.
2522        params = str(request.parameter)
2523    except UnicodeEncodeError:
2524        params = request.parameter
2525
2526    if ',' not in params and ' -' in params:
2527        # Nanopb has traditionally supported space as separator in options
2528        args = shlex.split(params)
2529    else:
2530        # Protoc separates options passed to plugins by comma
2531        # This allows also giving --nanopb_opt option multiple times.
2532        lex = shlex.shlex(params)
2533        lex.whitespace_split = True
2534        lex.whitespace = ','
2535        lex.commenters = ''
2536        args = list(lex)
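    # A typical plugin invocation thus looks like:
    #   protoc --nanopb_out=. --nanopb_opt=-v --nanopb_opt=-I. message.proto
    # where protoc joins the --nanopb_opt values with commas into the
    # parameter string parsed above ("-v,-I.").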
2537
2538    optparser.usage = "protoc --nanopb_out=outdir [--nanopb_opt=option] ['--nanopb_opt=option with spaces'] file.proto"
2539    optparser.epilog = "Output will be written to file.pb.h and file.pb.c."
2540
2541    if '-h' in args or '--help' in args:
2542        # By default optparser prints help to stdout, which doesn't work for
2543        # protoc plugins.
2544        optparser.print_help(sys.stderr)
2545        sys.exit(1)
2546
2547    options, dummy = process_cmdline(args, is_plugin = True)
2548
2549    response = plugin_pb2.CodeGeneratorResponse()
2550
2551    # Google's protoc does not currently indicate the full path of proto files.
    # Instead, always add the main file's path to the search dirs; that works
    # for the common case.
2554    import os.path
2555    options.options_path.append(os.path.dirname(request.file_to_generate[0]))
2556
2557    # Process any include files first, in order to have them
2558    # available as dependencies
2559    other_files = {}
2560    for fdesc in request.proto_file:
2561        other_files[fdesc.name] = parse_file(fdesc.name, fdesc, options)
2562
2563    for filename in request.file_to_generate:
2564        for fdesc in request.proto_file:
2565            if fdesc.name == filename:
2566                results = process_file(filename, fdesc, options, other_files)
2567
2568                f = response.file.add()
2569                f.name = results['headername']
2570                f.content = results['headerdata']
2571
2572                f = response.file.add()
2573                f.name = results['sourcename']
2574                f.content = results['sourcedata']
2575
2576    if hasattr(plugin_pb2.CodeGeneratorResponse, "FEATURE_PROTO3_OPTIONAL"):
2577        response.supported_features = plugin_pb2.CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL
2578
2579    io.open(sys.stdout.fileno(), "wb").write(response.SerializeToString())
2580
2581if __name__ == '__main__':
2582    # Check if we are running as a plugin under protoc
2583    if 'protoc-gen-' in sys.argv[0] or '--protoc-plugin' in sys.argv:
2584        main_plugin()
2585    else:
2586        main_cli()
2587