1#!/usr/bin/env python3
2# kate: replace-tabs on; indent-width 4;
3
4from __future__ import unicode_literals
5
6'''Generate header file for nanopb from a ProtoBuf FileDescriptorSet.'''
7nanopb_version = "nanopb-0.4.6-dev"
8
9import sys
10import re
11import codecs
12import copy
13import itertools
14import tempfile
15import shutil
16import os
17from functools import reduce
18
19try:
20    # Add some dummy imports to keep packaging tools happy.
21    import google, distutils.util # bbfreeze seems to need these
22    import pkg_resources # pyinstaller / protobuf 2.5 seem to need these
23    import proto.nanopb_pb2 as nanopb_pb2 # pyinstaller seems to need this
24    import pkg_resources.py2_warn
25except:
26    # Don't care, we will error out later if it is actually important.
27    pass
28
29try:
30    # Make sure grpc_tools gets included in binary package if it is available
31    import grpc_tools.protoc
32except:
33    pass
34
35try:
36    import google.protobuf.text_format as text_format
37    import google.protobuf.descriptor_pb2 as descriptor
38    import google.protobuf.compiler.plugin_pb2 as plugin_pb2
39    import google.protobuf.reflection as reflection
40    import google.protobuf.descriptor
41except:
42    sys.stderr.write('''
43         *************************************************************
44         *** Could not import the Google protobuf Python libraries ***
45         *** Try installing package 'python3-protobuf' or similar.  ***
46         *************************************************************
47    ''' + '\n')
48    raise
49
50try:
51    from .proto import nanopb_pb2
52    from .proto._utils import invoke_protoc
53except TypeError:
54    sys.stderr.write('''
55         ****************************************************************************
56         *** Got TypeError when importing the protocol definitions for generator. ***
57         *** This usually means that the protoc in your path doesn't match the    ***
58         *** Python protobuf library version.                                     ***
59         ***                                                                      ***
60         *** Please check the output of the following commands:                   ***
61         *** which protoc                                                         ***
62         *** protoc --version                                                     ***
63         *** python3 -c 'import google.protobuf; print(google.protobuf.__file__)'  ***
64         *** If you are not able to find the python protobuf version using the    ***
65         *** above command, use this command.                                     ***
66         *** pip freeze | grep -i protobuf                                        ***
67         ****************************************************************************
68    ''' + '\n')
69    raise
70except (ValueError, SystemError, ImportError):
71    # Probably invoked directly instead of via installed scripts.
72    import proto.nanopb_pb2 as nanopb_pb2
73    from proto._utils import invoke_protoc
74except:
75    sys.stderr.write('''
76         ********************************************************************
77         *** Failed to import the protocol definitions for generator.     ***
78         *** You have to run 'make' in the nanopb/generator/proto folder. ***
79         ********************************************************************
80    ''' + '\n')
81    raise
82
# tempfile.TemporaryDirectory exists on Python 3.2+; provide a minimal
# context-manager emulation for older interpreters.
try:
    from tempfile import TemporaryDirectory
except ImportError:
    class TemporaryDirectory:
        '''TemporaryDirectory fallback for Python 2'''
        def __enter__(self):
            # Create a fresh directory and hand its path to the with-block.
            self.dir = tempfile.mkdtemp()
            return self.dir

        def __exit__(self, *args):
            # Remove the directory and everything inside it on scope exit.
            shutil.rmtree(self.dir)
94
95# ---------------------------------------------------------------------------
96#                     Generation of single fields
97# ---------------------------------------------------------------------------
98
99import time
100import os.path
101
# Values are tuple (c type, pb type, encoded size, data_size)
# - c type:       C declaration type emitted into generated structs
# - pb type:      nanopb type identifier (consumed by generator code
#                 outside this chunk when emitting field descriptors)
# - encoded size: worst-case wire size of one encoded value, in bytes
# - data_size:    in-memory size of the C type, in bytes
FieldD = descriptor.FieldDescriptorProto
datatypes = {
    FieldD.TYPE_BOOL:       ('bool',     'BOOL',        1,  4),
    FieldD.TYPE_DOUBLE:     ('double',   'DOUBLE',      8,  8),
    FieldD.TYPE_FIXED32:    ('uint32_t', 'FIXED32',     4,  4),
    FieldD.TYPE_FIXED64:    ('uint64_t', 'FIXED64',     8,  8),
    FieldD.TYPE_FLOAT:      ('float',    'FLOAT',       4,  4),
    FieldD.TYPE_INT32:      ('int32_t',  'INT32',      10,  4),
    FieldD.TYPE_INT64:      ('int64_t',  'INT64',      10,  8),
    FieldD.TYPE_SFIXED32:   ('int32_t',  'SFIXED32',    4,  4),
    FieldD.TYPE_SFIXED64:   ('int64_t',  'SFIXED64',    8,  8),
    FieldD.TYPE_SINT32:     ('int32_t',  'SINT32',      5,  4),
    FieldD.TYPE_SINT64:     ('int64_t',  'SINT64',     10,  8),
    FieldD.TYPE_UINT32:     ('uint32_t', 'UINT32',      5,  4),
    FieldD.TYPE_UINT64:     ('uint64_t', 'UINT64',     10,  8),

    # Integer size override options
    # Keyed by (wire type, nanopb int_size option); note the encoded size
    # still reflects the wire format, only the C storage type shrinks/grows.
    (FieldD.TYPE_INT32,   nanopb_pb2.IS_8):   ('int8_t',   'INT32', 10,  1),
    (FieldD.TYPE_INT32,  nanopb_pb2.IS_16):   ('int16_t',  'INT32', 10,  2),
    (FieldD.TYPE_INT32,  nanopb_pb2.IS_32):   ('int32_t',  'INT32', 10,  4),
    (FieldD.TYPE_INT32,  nanopb_pb2.IS_64):   ('int64_t',  'INT32', 10,  8),
    (FieldD.TYPE_SINT32,  nanopb_pb2.IS_8):   ('int8_t',  'SINT32',  2,  1),
    (FieldD.TYPE_SINT32, nanopb_pb2.IS_16):   ('int16_t', 'SINT32',  3,  2),
    (FieldD.TYPE_SINT32, nanopb_pb2.IS_32):   ('int32_t', 'SINT32',  5,  4),
    (FieldD.TYPE_SINT32, nanopb_pb2.IS_64):   ('int64_t', 'SINT32', 10,  8),
    (FieldD.TYPE_UINT32,  nanopb_pb2.IS_8):   ('uint8_t', 'UINT32',  2,  1),
    (FieldD.TYPE_UINT32, nanopb_pb2.IS_16):   ('uint16_t','UINT32',  3,  2),
    (FieldD.TYPE_UINT32, nanopb_pb2.IS_32):   ('uint32_t','UINT32',  5,  4),
    (FieldD.TYPE_UINT32, nanopb_pb2.IS_64):   ('uint64_t','UINT32', 10,  8),
    (FieldD.TYPE_INT64,   nanopb_pb2.IS_8):   ('int8_t',   'INT64', 10,  1),
    (FieldD.TYPE_INT64,  nanopb_pb2.IS_16):   ('int16_t',  'INT64', 10,  2),
    (FieldD.TYPE_INT64,  nanopb_pb2.IS_32):   ('int32_t',  'INT64', 10,  4),
    (FieldD.TYPE_INT64,  nanopb_pb2.IS_64):   ('int64_t',  'INT64', 10,  8),
    (FieldD.TYPE_SINT64,  nanopb_pb2.IS_8):   ('int8_t',  'SINT64',  2,  1),
    (FieldD.TYPE_SINT64, nanopb_pb2.IS_16):   ('int16_t', 'SINT64',  3,  2),
    (FieldD.TYPE_SINT64, nanopb_pb2.IS_32):   ('int32_t', 'SINT64',  5,  4),
    (FieldD.TYPE_SINT64, nanopb_pb2.IS_64):   ('int64_t', 'SINT64', 10,  8),
    (FieldD.TYPE_UINT64,  nanopb_pb2.IS_8):   ('uint8_t', 'UINT64',  2,  1),
    (FieldD.TYPE_UINT64, nanopb_pb2.IS_16):   ('uint16_t','UINT64',  3,  2),
    (FieldD.TYPE_UINT64, nanopb_pb2.IS_32):   ('uint32_t','UINT64',  5,  4),
    (FieldD.TYPE_UINT64, nanopb_pb2.IS_64):   ('uint64_t','UINT64', 10,  8),
}
145
class Globals:
    '''Ugly global variables, should find a good way to pass these.

    NOTE(review): the values below are defaults; they appear to be written
    by option/command-line handling code outside this chunk — verify there
    before relying on them.
    '''
    verbose_options = False
    separate_options = []
    matched_namemasks = set()
    protoc_insertion_points = False
152
# String types (for python 2 / python 3 compatibility)
# Probe for the Python 2 `unicode` builtin: on Python 2 both unicode and str
# count as strings and text files are opened in universal-newline mode
# ('rU'); on Python 3 the NameError branch leaves plain str and 'r'.
try:
    strtypes = (unicode, str)
    openmode_unicode = 'rU'
except NameError:
    strtypes = (str, )
    openmode_unicode = 'r'
160
161
class Names:
    '''Keeps a set of nested names and formats them to C identifier.

    Instances behave as immutable value objects: equality, ordering and
    hashing are all derived from the tuple of name parts.
    '''
    def __init__(self, parts = ()):
        '''parts may be another Names, a single string, or an iterable of strings.'''
        if isinstance(parts, Names):
            parts = parts.parts
        elif isinstance(parts, strtypes):
            parts = (parts,)
        self.parts = tuple(parts)

    def __str__(self):
        # The C identifier is formed by joining the parts with underscores.
        return '_'.join(self.parts)

    def __repr__(self):
        return 'Names(%r)' % (self.parts,)

    def __add__(self, other):
        '''Return a new Names with the other name part(s) appended.'''
        if isinstance(other, strtypes):
            return Names(self.parts + (other,))
        elif isinstance(other, Names):
            return Names(self.parts + other.parts)
        elif isinstance(other, tuple):
            return Names(self.parts + other)
        else:
            raise ValueError("Name parts should be of type str")

    def __eq__(self, other):
        return isinstance(other, Names) and self.parts == other.parts

    def __hash__(self):
        # Fix: defining __eq__ without __hash__ makes instances unhashable
        # in Python 3. Hashing by parts keeps hash consistent with equality
        # and lets Names be used as dict keys / set members.
        return hash(self.parts)

    def __lt__(self, other):
        if not isinstance(other, Names):
            return NotImplemented
        return str(self) < str(other)
191
def names_from_type_name(type_name):
    '''Parse Names() from FieldDescriptorProto type_name.

    type_name must be fully qualified, i.e. start with a dot.'''
    if type_name[0] != '.':
        raise NotImplementedError("Lookup of non-absolute type names is not supported")
    parts = type_name[1:].split('.')
    return Names(parts)
197
def varint_max_size(max_value):
    '''Returns the maximum number of bytes a varint can take when encoded.'''
    if max_value < 0:
        # Negative values are encoded as 64-bit two's complement,
        # which always needs the full 10 bytes.
        max_value = 2**64 - max_value
    nbytes = 1
    while nbytes <= 10:
        if max_value >> (nbytes * 7) == 0:
            return nbytes
        nbytes += 1
    raise ValueError("Value too large for varint: " + str(max_value))
206
# Import-time sanity checks for varint_max_size().
assert varint_max_size(-1) == 10
assert varint_max_size(0) == 1
assert varint_max_size(127) == 1
assert varint_max_size(128) == 2
211
class EncodedSize:
    '''Class used to represent the encoded size of a field or a message.
    Consists of a combination of symbolic sizes and integer sizes.'''
    def __init__(self, value = 0, symbols = None, declarations = None, required_defines = None):
        '''value may be an int, another EncodedSize, a string or a Names
        (interpreted as one symbolic size).

        Fix: the list parameters previously defaulted to shared mutable []
        literals that were aliased directly into self.*; None sentinels give
        every instance its own lists.
        '''
        if isinstance(value, EncodedSize):
            self.value = value.value
            self.symbols = value.symbols
            self.declarations = value.declarations
            self.required_defines = value.required_defines
        elif isinstance(value, strtypes + (Names,)):
            self.symbols = [str(value)]
            self.value = 0
            self.declarations = []
            self.required_defines = [str(value)]
        else:
            self.value = value
            self.symbols = symbols if symbols is not None else []
            self.declarations = declarations if declarations is not None else []
            self.required_defines = required_defines if required_defines is not None else []

    def __add__(self, other):
        '''Add an integer, a symbolic size (str/Names) or another EncodedSize.'''
        if isinstance(other, int):
            return EncodedSize(self.value + other, self.symbols, self.declarations, self.required_defines)
        elif isinstance(other, strtypes + (Names,)):
            return EncodedSize(self.value, self.symbols + [str(other)], self.declarations, self.required_defines + [str(other)])
        elif isinstance(other, EncodedSize):
            return EncodedSize(self.value + other.value, self.symbols + other.symbols,
                               self.declarations + other.declarations, self.required_defines + other.required_defines)
        else:
            raise ValueError("Cannot add size: " + repr(other))

    def __mul__(self, other):
        '''Multiply by an integer count; symbolic parts become 'N*SYMBOL'.'''
        if isinstance(other, int):
            return EncodedSize(self.value * other, [str(other) + '*' + s for s in self.symbols],
                               self.declarations, self.required_defines)
        else:
            raise ValueError("Cannot multiply size: " + repr(other))

    def __str__(self):
        '''Format as a C constant expression, e.g. "(4 + Foo_size)".'''
        if not self.symbols:
            return str(self.value)
        else:
            return '(' + str(self.value) + ' + ' + ' + '.join(self.symbols) + ')'

    def get_declarations(self):
        '''Get any declarations that must appear alongside this encoded size definition,
        such as helper union {} types.'''
        return '\n'.join(self.declarations)

    def get_cpp_guard(self, local_defines):
        '''Get an #if preprocessor statement listing all defines that are required for this definition.'''
        needed = [x for x in self.required_defines if x not in local_defines]
        if needed:
            return '#if ' + ' && '.join(['defined(%s)' % x for x in needed]) + "\n"
        else:
            return ''

    def upperlimit(self):
        '''Numeric upper bound; symbolic sizes are treated as unbounded
        and fall back to the maximum message size (2**32 - 1).'''
        if not self.symbols:
            return self.value
        else:
            return 2**32 - 1
274
275
276'''
277Constants regarding path of proto elements in file descriptor.
278They are used to connect proto elements with source code information (comments)
279These values come from:
280    https://github.com/google/protobuf/blob/master/src/google/protobuf/descriptor.proto
281'''
282MESSAGE_PATH = 4
283ENUM_PATH = 5
284FIELD_PATH = 2
285
286
class ProtoElement(object):
    '''Base helper connecting a proto element (message/enum/...) with the
    source comment information gathered from SourceCodeInfo.'''

    def __init__(self, path, index, comments):
        '''
        path is a predefined value for each element type in proto file.
            For example, message == 4, enum == 5, service == 6
        index is the N-th occurance of the `path` in the proto file.
            For example, 4-th message in the proto file or 2-nd enum etc ...
        comments is a dictionary mapping between element path & SourceCodeInfo.Location
            (contains information about source comments).
        '''
        self.path = path
        self.index = index
        self.comments = comments

    def element_path(self):
        '''Get path to proto element.'''
        return [self.path, self.index]

    def member_path(self, member_index):
        '''Get path to member of proto element.
        Example paths:
        [4, m] - message comments, m: msgIdx in proto from 0
        [4, m, 2, f] - field comments in message, f: fieldIdx in message from 0
        [6, s] - service comments, s: svcIdx in proto from 0
        [6, s, 2, r] - rpc comments in service, r: rpc method def in service from 0
        '''
        return [self.path, self.index, FIELD_PATH, member_index]

    def get_comments(self, path, leading_indent=True):
        '''Get leading & trailing comments for enum member based on path.

        path is the proto path of an element or member (ex. [5 0] or [4 1 2 0])
        leading_indent is a flag that indicates if leading comments should be indented
        '''
        # The comments dict is keyed by the string form of the path list.
        location = self.comments.get(str(path))
        if not location:
            return "", ""

        leading = ""
        if location.leading_comments:
            indent = "    " if leading_indent else ""
            leading = indent + "/* %s */" % location.leading_comments.strip()

        trailing = ""
        if location.trailing_comments:
            trailing = "/* %s */" % location.trailing_comments.strip()

        return leading, trailing
340
341
class Enum(ProtoElement):
    '''Represents one enum from the .proto file and renders its C typedef
    enum plus auxiliary #defines (_MIN/_MAX/_ARRAYSIZE and name aliases).'''

    def __init__(self, names, desc, enum_options, index, comments):
        '''
        desc is EnumDescriptorProto
        index is the index of this enum element inside the file
        comments is a dictionary mapping between element path & SourceCodeInfo.Location
            (contains information about source comments)
        '''
        super(Enum, self).__init__(ENUM_PATH, index, comments)

        self.options = enum_options
        self.names = names

        # by definition, `names` include this enum's name
        base_name = Names(names.parts[:-1])

        # With long_names the value identifiers are prefixed with the enum
        # name itself; otherwise only with the enclosing scope's name.
        if enum_options.long_names:
            self.values = [(names + x.name, x.number) for x in desc.value]
        else:
            self.values = [(base_name + x.name, x.number) for x in desc.value]

        # Long names are always tracked so cross-file value references can
        # be aliased in auxiliary_defines() when long_names is off.
        self.value_longnames = [self.names + x.name for x in desc.value]
        self.packed = enum_options.packed_enum

    def has_negative(self):
        '''Return True if any value of this enum is negative.'''
        for n, v in self.values:
            if v < 0:
                return True
        return False

    def encoded_size(self):
        '''Worst-case encoded size in bytes over all values of this enum.'''
        return max([varint_max_size(v) for n,v in self.values])

    def __str__(self):
        '''Format the complete C "typedef enum" definition, interleaving
        any leading/trailing source comments from the .proto file.'''
        enum_path = self.element_path()
        leading_comment, trailing_comment = self.get_comments(enum_path, leading_indent=False)

        result = ''
        if leading_comment:
            result = '%s\n' % leading_comment

        result += 'typedef enum _%s { %s\n' % (self.names, trailing_comment)

        enum_length = len(self.values)
        enum_values = []
        for index, (name, value) in enumerate(self.values):
            member_path = self.member_path(index)
            leading_comment, trailing_comment = self.get_comments(member_path)

            if leading_comment:
                enum_values.append(leading_comment)

            comma = ","
            if index == enum_length - 1:
                # last enum member should not end with a comma
                comma = ""

            enum_values.append("    %s = %d%s %s" % (name, value, comma, trailing_comment))

        result += '\n'.join(enum_values)
        result += '\n}'

        if self.packed:
            result += ' pb_packed'

        result += ' %s;' % self.names
        return result

    def auxiliary_defines(self):
        '''#defines for the value range (_MIN/_MAX/_ARRAYSIZE), long-name
        aliases, and the optional enum-to-string helper declaration.'''
        # sort the enum by value
        sorted_values = sorted(self.values, key = lambda x: (x[1], x[0]))
        result  = '#define _%s_MIN %s\n' % (self.names, sorted_values[0][0])
        result += '#define _%s_MAX %s\n' % (self.names, sorted_values[-1][0])
        result += '#define _%s_ARRAYSIZE ((%s)(%s+1))\n' % (self.names, self.names, sorted_values[-1][0])

        if not self.options.long_names:
            # Define the long names always so that enum value references
            # from other files work properly.
            for i, x in enumerate(self.values):
                result += '#define %s %s\n' % (self.value_longnames[i], x[0])

        if self.options.enum_to_string:
            result += 'const char *%s_name(%s v);\n' % (self.names, self.names)

        return result

    def enum_to_string_definition(self):
        '''C implementation of the <enum>_name() lookup function,
        or an empty string if the enum_to_string option is off.'''
        if not self.options.enum_to_string:
            return ""

        result = 'const char *%s_name(%s v) {\n' % (self.names, self.names)
        result += '    switch (v) {\n'

        for ((enumname, _), strname) in zip(self.values, self.value_longnames):
            # Strip off the leading type name from the string value.
            strval = str(strname)[len(str(self.names)) + 1:]
            result += '        case %s: return "%s";\n' % (enumname, strval)

        result += '    }\n'
        result += '    return "unknown";\n'
        result += '}\n'

        return result
445
class FieldMaxSize:
    '''Tracks the worst-case encoded size over a set of fields, remembering
    which field produced it, plus any textual size checks to emit.'''

    def __init__(self, worst = 0, checks = None, field_name = 'undefined'):
        '''worst: a single size, or a list of candidate sizes (None entries
        are ignored). checks: optional iterable of check strings (copied).
        field_name: name reported as the origin of the worst-case size.

        Fixes: an empty or all-None `worst` list previously raised
        ValueError from max(); the `checks=[]` mutable default is replaced
        with a None sentinel.
        '''
        if isinstance(worst, list):
            candidates = [i for i in worst if i is not None]
            # Empty list of candidates means "no size known yet" -> 0.
            self.worst = max(candidates) if candidates else 0
        else:
            self.worst = worst

        self.worst_field = field_name
        self.checks = list(checks) if checks is not None else []

    def extend(self, extend, field_name = None):
        '''Merge another FieldMaxSize into this one, keeping the larger
        worst-case size and accumulating the checks.

        field_name is unused; kept for interface compatibility.
        '''
        self.worst = max(self.worst, extend.worst)

        # On a tie, attribute the worst case to the newly merged field.
        if self.worst == extend.worst:
            self.worst_field = extend.worst_field

        self.checks.extend(extend.checks)
463
class Field:
    # Single-character parameter names used when emitting field list macros.
    # NOTE(review): their exact usage is in generator code outside this
    # chunk — confirm there before renaming.
    macro_x_param = 'X'
    macro_a_param = 'a'
467
    def __init__(self, struct_name, desc, field_options):
        '''desc is FieldDescriptorProto

        struct_name is the name of the enclosing message struct (used to
        derive per-field helper type names). field_options is the nanopb
        options message for this field; note it is modified in place below
        (FT_INLINE compatibility and FT_DEFAULT resolution).
        '''
        self.tag = desc.number
        self.struct_name = struct_name
        self.union_name = None
        self.name = desc.name
        self.default = None
        self.max_size = None
        self.max_count = None
        self.array_decl = ""
        self.enc_size = None
        self.data_item_size = None
        self.ctype = None
        self.fixed_count = False
        self.callback_datatype = field_options.callback_datatype
        self.math_include_required = False
        self.sort_by_tag = field_options.sort_by_tag

        if field_options.type == nanopb_pb2.FT_INLINE:
            # Before nanopb-0.3.8, fixed length bytes arrays were specified
            # by setting type to FT_INLINE. But to handle pointer typed fields,
            # it makes sense to have it as a separate option.
            field_options.type = nanopb_pb2.FT_STATIC
            field_options.fixed_length = True

        # Parse field options
        if field_options.HasField("max_size"):
            self.max_size = field_options.max_size

        self.default_has = field_options.default_has

        if desc.type == FieldD.TYPE_STRING and field_options.HasField("max_length"):
            # max_length overrides max_size for strings
            self.max_size = field_options.max_length + 1

        if field_options.HasField("max_count"):
            self.max_count = field_options.max_count

        if desc.HasField('default_value'):
            self.default = desc.default_value

        # Check field rules, i.e. required/optional/repeated.
        can_be_static = True
        if desc.label == FieldD.LABEL_REPEATED:
            self.rules = 'REPEATED'
            if self.max_count is None:
                can_be_static = False
            else:
                self.array_decl = '[%d]' % self.max_count
                if field_options.fixed_count:
                  self.rules = 'FIXARRAY'

        elif field_options.proto3:
            if desc.type == FieldD.TYPE_MESSAGE and not field_options.proto3_singular_msgs:
                # In most other protobuf libraries proto3 submessages have
                # "null" status. For nanopb, that is implemented as has_ field.
                self.rules = 'OPTIONAL'
            elif hasattr(desc, "proto3_optional") and desc.proto3_optional:
                # Protobuf 3.12 introduced optional fields for proto3 syntax
                self.rules = 'OPTIONAL'
            else:
                # Proto3 singular fields (without has_field)
                self.rules = 'SINGULAR'
        elif desc.label == FieldD.LABEL_REQUIRED:
            self.rules = 'REQUIRED'
        elif desc.label == FieldD.LABEL_OPTIONAL:
            self.rules = 'OPTIONAL'
        else:
            raise NotImplementedError(desc.label)

        # Check if the field can be implemented with static allocation
        # i.e. whether the data size is known.
        if desc.type == FieldD.TYPE_STRING and self.max_size is None:
            can_be_static = False

        if desc.type == FieldD.TYPE_BYTES and self.max_size is None:
            can_be_static = False

        # Decide how the field data will be allocated
        if field_options.type == nanopb_pb2.FT_DEFAULT:
            if can_be_static:
                field_options.type = nanopb_pb2.FT_STATIC
            else:
                field_options.type = nanopb_pb2.FT_CALLBACK

        if field_options.type == nanopb_pb2.FT_STATIC and not can_be_static:
            raise Exception("Field '%s' is defined as static, but max_size or "
                            "max_count is not given." % self.name)

        if field_options.fixed_count and self.max_count is None:
            raise Exception("Field '%s' is defined as fixed count, "
                            "but max_count is not given." % self.name)

        if field_options.type == nanopb_pb2.FT_STATIC:
            self.allocation = 'STATIC'
        elif field_options.type == nanopb_pb2.FT_POINTER:
            self.allocation = 'POINTER'
        elif field_options.type == nanopb_pb2.FT_CALLBACK:
            self.allocation = 'CALLBACK'
        else:
            raise NotImplementedError(field_options.type)

        if field_options.HasField("type_override"):
            # User-requested wire type override replaces the .proto type.
            desc.type = field_options.type_override

        # Decide the C data type to use in the struct.
        if desc.type in datatypes:
            self.ctype, self.pbtype, self.enc_size, self.data_item_size = datatypes[desc.type]

            # Override the field size if user wants to use smaller integers
            if (desc.type, field_options.int_size) in datatypes:
                self.ctype, self.pbtype, self.enc_size, self.data_item_size = datatypes[(desc.type, field_options.int_size)]
        elif desc.type == FieldD.TYPE_ENUM:
            self.pbtype = 'ENUM'
            self.data_item_size = 4
            self.ctype = names_from_type_name(desc.type_name)
            if self.default is not None:
                # Prefix default value with the enum type name to form the
                # full C identifier of the enum constant.
                self.default = self.ctype + self.default
            self.enc_size = None # Needs to be filled in when enum values are known
        elif desc.type == FieldD.TYPE_STRING:
            self.pbtype = 'STRING'
            self.ctype = 'char'
            if self.allocation == 'STATIC':
                self.ctype = 'char'
                self.array_decl += '[%d]' % self.max_size
                # -1 because of null terminator. Both pb_encode and pb_decode
                # check the presence of it.
                self.enc_size = varint_max_size(self.max_size) + self.max_size - 1
        elif desc.type == FieldD.TYPE_BYTES:
            if field_options.fixed_length:
                self.pbtype = 'FIXED_LENGTH_BYTES'

                if self.max_size is None:
                    raise Exception("Field '%s' is defined as fixed length, "
                                    "but max_size is not given." % self.name)

                self.enc_size = varint_max_size(self.max_size) + self.max_size
                self.ctype = 'pb_byte_t'
                self.array_decl += '[%d]' % self.max_size
            else:
                self.pbtype = 'BYTES'
                self.ctype = 'pb_bytes_array_t'
                if self.allocation == 'STATIC':
                    # Per-field helper type, defined by types() below.
                    self.ctype = self.struct_name + self.name + 't'
                    self.enc_size = varint_max_size(self.max_size) + self.max_size
        elif desc.type == FieldD.TYPE_MESSAGE:
            self.pbtype = 'MESSAGE'
            self.ctype = self.submsgname = names_from_type_name(desc.type_name)
            self.enc_size = None # Needs to be filled in after the message type is available
            if field_options.submsg_callback and self.allocation == 'STATIC':
                self.pbtype = 'MSG_W_CB'
        else:
            raise NotImplementedError(desc.type)

        # Float/double defaults of inf/nan need <math.h> in the generated code.
        if self.default and self.pbtype in ['FLOAT', 'DOUBLE']:
            if 'inf' in self.default or 'nan' in self.default:
                self.math_include_required = True
625
626    def __lt__(self, other):
627        return self.tag < other.tag
628
629    def __str__(self):
630        result = ''
631        if self.allocation == 'POINTER':
632            if self.rules == 'REPEATED':
633                if self.pbtype == 'MSG_W_CB':
634                    result += '    pb_callback_t cb_' + self.name + ';\n'
635                result += '    pb_size_t ' + self.name + '_count;\n'
636
637            if self.pbtype in ['MESSAGE', 'MSG_W_CB']:
638                # Use struct definition, so recursive submessages are possible
639                result += '    struct _%s *%s;' % (self.ctype, self.name)
640            elif self.pbtype == 'FIXED_LENGTH_BYTES' or self.rules == 'FIXARRAY':
641                # Pointer to fixed size array
642                result += '    %s (*%s)%s;' % (self.ctype, self.name, self.array_decl)
643            elif self.rules in ['REPEATED', 'FIXARRAY'] and self.pbtype in ['STRING', 'BYTES']:
644                # String/bytes arrays need to be defined as pointers to pointers
645                result += '    %s **%s;' % (self.ctype, self.name)
646            else:
647                result += '    %s *%s;' % (self.ctype, self.name)
648        elif self.allocation == 'CALLBACK':
649            result += '    %s %s;' % (self.callback_datatype, self.name)
650        else:
651            if self.pbtype == 'MSG_W_CB' and self.rules in ['OPTIONAL', 'REPEATED']:
652                result += '    pb_callback_t cb_' + self.name + ';\n'
653
654            if self.rules == 'OPTIONAL':
655                result += '    bool has_' + self.name + ';\n'
656            elif self.rules == 'REPEATED':
657                result += '    pb_size_t ' + self.name + '_count;\n'
658            result += '    %s %s%s;' % (self.ctype, self.name, self.array_decl)
659        return result
660
661    def types(self):
662        '''Return definitions for any special types this field might need.'''
663        if self.pbtype == 'BYTES' and self.allocation == 'STATIC':
664            result = 'typedef PB_BYTES_ARRAY_T(%d) %s;\n' % (self.max_size, self.ctype)
665        else:
666            result = ''
667        return result
668
669    def get_dependencies(self):
670        '''Get list of type names used by this field.'''
671        if self.allocation == 'STATIC':
672            return [str(self.ctype)]
673        else:
674            return []
675
    def get_initializer(self, null_init, inner_init_only = False):
        '''Return literal expression for this field's default value.
        null_init: If True, initialize to a 0 value instead of default from .proto
        inner_init_only: If True, exclude initialization for any count/has fields
        '''

        # First build the initializer for the value itself ("inner"),
        # then wrap it with has_/count/pointer parts ("outer") below.
        inner_init = None
        if self.pbtype in ['MESSAGE', 'MSG_W_CB']:
            if null_init:
                inner_init = '%s_init_zero' % self.ctype
            else:
                inner_init = '%s_init_default' % self.ctype
        elif self.default is None or null_init:
            if self.pbtype == 'STRING':
                inner_init = '""'
            elif self.pbtype == 'BYTES':
                inner_init = '{0, {0}}'
            elif self.pbtype == 'FIXED_LENGTH_BYTES':
                inner_init = '{0}'
            elif self.pbtype in ('ENUM', 'UENUM'):
                inner_init = '_%s_MIN' % self.ctype
            else:
                inner_init = '0'
        else:
            if self.pbtype == 'STRING':
                # NOTE: codecs.escape_encode/escape_decode are undocumented
                # internal codecs; they produce C-style escaped byte strings.
                data = codecs.escape_encode(self.default.encode('utf-8'))[0]
                inner_init = '"' + data.decode('ascii') + '"'
            elif self.pbtype == 'BYTES':
                data = codecs.escape_decode(self.default)[0]
                data = ["0x%02x" % c for c in bytearray(data)]
                if len(data) == 0:
                    inner_init = '{0, {0}}'
                else:
                    inner_init = '{%d, {%s}}' % (len(data), ','.join(data))
            elif self.pbtype == 'FIXED_LENGTH_BYTES':
                data = codecs.escape_decode(self.default)[0]
                data = ["0x%02x" % c for c in bytearray(data)]
                if len(data) == 0:
                    inner_init = '{0}'
                else:
                    inner_init = '{%s}' % ','.join(data)
            elif self.pbtype in ['FIXED32', 'UINT32']:
                inner_init = str(self.default) + 'u'
            elif self.pbtype in ['FIXED64', 'UINT64']:
                inner_init = str(self.default) + 'ull'
            elif self.pbtype in ['SFIXED64', 'INT64']:
                inner_init = str(self.default) + 'll'
            elif self.pbtype in ['FLOAT', 'DOUBLE']:
                inner_init = str(self.default)
                # Map protobuf inf/nan spellings to <math.h> macros; plain
                # float literals additionally get an 'f' suffix.
                if 'inf' in inner_init:
                    inner_init = inner_init.replace('inf', 'INFINITY')
                elif 'nan' in inner_init:
                    inner_init = inner_init.replace('nan', 'NAN')
                elif (not '.' in inner_init) and self.pbtype == 'FLOAT':
                    inner_init += '.0f'
                elif self.pbtype == 'FLOAT':
                    inner_init += 'f'
            else:
                inner_init = str(self.default)

        if inner_init_only:
            return inner_init

        outer_init = None
        if self.allocation == 'STATIC':
            if self.rules == 'REPEATED':
                # Count field first, then the array of default values.
                outer_init = '0, {' + ', '.join([inner_init] * self.max_count) + '}'
            elif self.rules == 'FIXARRAY':
                outer_init = '{' + ', '.join([inner_init] * self.max_count) + '}'
            elif self.rules == 'OPTIONAL':
                if null_init or not self.default_has:
                    outer_init = 'false, ' + inner_init
                else:
                    outer_init = 'true, ' + inner_init
            else:
                outer_init = inner_init
        elif self.allocation == 'POINTER':
            if self.rules == 'REPEATED':
                outer_init = '0, NULL'
            else:
                outer_init = 'NULL'
        elif self.allocation == 'CALLBACK':
            if self.pbtype == 'EXTENSION':
                outer_init = 'NULL'
            else:
                outer_init = '{{NULL}, NULL}'

        # Message-with-callback fields carry an extra pb_callback_t member
        # before the ordinary initializer.
        if self.pbtype == 'MSG_W_CB' and self.rules in ['REPEATED', 'OPTIONAL']:
            outer_init = '{{NULL}, NULL}, ' + outer_init

        return outer_init
767
768    def tags(self):
769        '''Return the #define for the tag number of this field.'''
770        identifier = '%s_%s_tag' % (self.struct_name, self.name)
771        return '#define %-40s %d\n' % (identifier, self.tag)
772
773    def fieldlist(self):
774        '''Return the FIELDLIST macro entry for this field.
775        Format is: X(a, ATYPE, HTYPE, LTYPE, field_name, tag)
776        '''
777        name = self.name
778
779        if self.rules == "ONEOF":
780          # For oneofs, make a tuple of the union name, union member name,
781          # and the name inside the parent struct.
782          if not self.anonymous:
783            name = '(%s,%s,%s)' % (self.union_name, self.name, self.union_name + '.' + self.name)
784          else:
785            name = '(%s,%s,%s)' % (self.union_name, self.name, self.name)
786
787        return '%s(%s, %-9s %-9s %-9s %-16s %3d)' % (self.macro_x_param,
788                                                     self.macro_a_param,
789                                                     self.allocation + ',',
790                                                     self.rules + ',',
791                                                     self.pbtype + ',',
792                                                     name + ',',
793                                                     self.tag)
794
795    def data_size(self, dependencies):
796        '''Return estimated size of this field in the C struct.
797        This is used to try to automatically pick right descriptor size.
798        If the estimate is wrong, it will result in compile time error and
799        user having to specify descriptor_width option.
800        '''
801        if self.allocation == 'POINTER' or self.pbtype == 'EXTENSION':
802            size = 8
803            alignment = 8
804        elif self.allocation == 'CALLBACK':
805            size = 16
806            alignment = 8
807        elif self.pbtype in ['MESSAGE', 'MSG_W_CB']:
808            alignment = 8
809            if str(self.submsgname) in dependencies:
810                other_dependencies = dict(x for x in dependencies.items() if x[0] != str(self.struct_name))
811                size = dependencies[str(self.submsgname)].data_size(other_dependencies)
812            else:
813                size = 256 # Message is in other file, this is reasonable guess for most cases
814
815            if self.pbtype == 'MSG_W_CB':
816                size += 16
817        elif self.pbtype in ['STRING', 'FIXED_LENGTH_BYTES']:
818            size = self.max_size
819            alignment = 4
820        elif self.pbtype == 'BYTES':
821            size = self.max_size + 4
822            alignment = 4
823        elif self.data_item_size is not None:
824            size = self.data_item_size
825            alignment = 4
826            if self.data_item_size >= 8:
827                alignment = 8
828        else:
829            raise Exception("Unhandled field type: %s" % self.pbtype)
830
831        if self.rules in ['REPEATED', 'FIXARRAY'] and self.allocation == 'STATIC':
832            size *= self.max_count
833
834        if self.rules not in ('REQUIRED', 'SINGULAR'):
835            size += 4
836
837        if size % alignment != 0:
838            # Estimate how much alignment requirements will increase the size.
839            size += alignment - (size % alignment)
840
841        return size
842
    def encoded_size(self, dependencies):
        '''Return the maximum size that this field can take when encoded,
        including the field tag. If the size cannot be determined, returns
        None.

        dependencies maps fully qualified type names to the Message/Enum
        objects parsed so far. The return value is an EncodedSize, which
        may contain symbolic parts resolved only at C compile time.'''

        if self.allocation != 'STATIC':
            # POINTER and CALLBACK fields have no static upper bound.
            return None

        if self.pbtype in ['MESSAGE', 'MSG_W_CB']:
            encsize = None
            if str(self.submsgname) in dependencies:
                submsg = dependencies[str(self.submsgname)]
                # Exclude ourselves to avoid infinite recursion for
                # (mutually) recursive message types.
                other_dependencies = dict(x for x in dependencies.items() if x[0] != str(self.struct_name))
                encsize = submsg.encoded_size(other_dependencies)

                my_msg = dependencies.get(str(self.struct_name))
                external = (not my_msg or submsg.protofile != my_msg.protofile)

                if encsize and encsize.symbols and external:
                    # Couldn't fully resolve the size of a dependency from
                    # another file. Instead of including the symbols directly,
                    # just use the #define SubMessage_size from the header.
                    encsize = None

                if encsize is not None:
                    # Include submessage length prefix
                    encsize += varint_max_size(encsize.upperlimit())
                elif not external:
                    # The dependency is from the same file and size cannot be
                    # determined for it, thus we know it will not be possible
                    # in runtime either.
                    return None

            if encsize is None:
                # Submessage or its size cannot be found.
                # This can occur if submessage is defined in different
                # file, and it or its .options could not be found.
                # Instead of direct numeric value, reference the size that
                # has been #defined in the other file.
                encsize = EncodedSize(self.submsgname + 'size')

                # We will have to make a conservative assumption on the length
                # prefix size, though.
                encsize += 5

        elif self.pbtype in ['ENUM', 'UENUM']:
            if str(self.ctype) in dependencies:
                enumtype = dependencies[str(self.ctype)]
                encsize = enumtype.encoded_size()
            else:
                # Conservative assumption
                encsize = 10

        elif self.enc_size is None:
            raise RuntimeError("Could not determine encoded size for %s.%s"
                               % (self.struct_name, self.name))
        else:
            # Scalar types have a precomputed per-item encoded size.
            encsize = EncodedSize(self.enc_size)

        encsize += varint_max_size(self.tag << 3) # Tag + wire type

        if self.rules in ['REPEATED', 'FIXARRAY']:
            # Decoders must be always able to handle unpacked arrays.
            # Therefore we have to reserve space for it, even though
            # we emit packed arrays ourselves. For length of 1, packed
            # arrays are larger however so we need to add allowance
            # for the length byte.
            encsize *= self.max_count

            if self.max_count == 1:
                encsize += 1

        return encsize
916
    def has_callbacks(self):
        '''Return True if this field is handled through a callback at runtime.'''
        return self.allocation == 'CALLBACK'
919
920    def requires_custom_field_callback(self):
921        return self.allocation == 'CALLBACK' and self.callback_datatype != 'pb_callback_t'
922
class ExtensionRange(Field):
    def __init__(self, struct_name, range_start, field_options):
        '''Implements a special pb_extension_t* field in an extensible message
        structure. The range_start signifies the index at which the extensions
        start. Not necessarily all tags above this are extensions, it is merely
        a speed optimization.
        '''
        # Identity of the pseudo-field inside its parent struct.
        self.struct_name = struct_name
        self.name = 'extensions'
        self.tag = range_start

        # Extensions are always handled through a callback pointer.
        self.pbtype = 'EXTENSION'
        self.rules = 'OPTIONAL'
        self.allocation = 'CALLBACK'
        self.ctype = 'pb_extension_t'
        self.callback_datatype = 'pb_extension_t*'

        # No static storage is reserved for the extension payload itself.
        self.array_decl = ''
        self.default = None
        self.max_size = 0
        self.max_count = 0
        self.data_item_size = 0
        self.fixed_count = False

    def requires_custom_field_callback(self):
        # The extension mechanism has its own dispatch; no custom callback.
        return False

    def __str__(self):
        return '    pb_extension_t *extensions;'

    def types(self):
        return ''

    def tags(self):
        return ''

    def encoded_size(self, dependencies):
        # We exclude extensions from the count, because they cannot be known
        # until runtime. Other option would be to return None here, but this
        # way the value remains useful if extensions are not used.
        return EncodedSize(0)
962
class ExtensionField(Field):
    '''A field declared with "extend SomeMessage { ... }" in a .proto file.
    Generates a pb_extension_type_t plus a single-field carrier message.'''

    def __init__(self, fullname, desc, field_options):
        self.fullname = fullname
        self.extendee_name = names_from_type_name(desc.extendee)
        Field.__init__(self, self.fullname + "extmsg", desc, field_options)

        if self.rules != 'OPTIONAL':
            # Only "optional" extension fields are supported; others are
            # skipped with a note in the generated header (see extension_decl).
            self.skip = True
        else:
            self.skip = False
            self.rules = 'REQUIRED' # We don't really want the has_field for extensions
            # currently no support for comments for extension fields => provide 0, {}
            self.msg = Message(self.fullname + "extmsg", None, field_options, 0, {})
            self.msg.fields.append(self)

    def tags(self):
        '''Return the #define for the tag number of this field.'''
        identifier = '%s_tag' % self.fullname
        return '#define %-40s %d\n' % (identifier, self.tag)

    def extension_decl(self):
        '''Declaration of the extension type in the .pb.h file'''
        if self.skip:
            msg = '/* Extension field %s was skipped because only "optional"\n' % self.fullname
            msg +='   type of extension fields is currently supported. */\n'
            return msg

        return ('extern const pb_extension_type_t %s; /* field type: %s */\n' %
            (self.fullname, str(self).strip()))

    def extension_def(self, dependencies):
        '''Definition of the extension type in the .pb.c file'''

        if self.skip:
            return ''

        # Emit the carrier message's struct, field list and descriptor,
        # then the pb_extension_type_t that ties them together.
        result = "/* Definition for extension field %s */\n" % self.fullname
        result += str(self.msg)
        result += self.msg.fields_declaration(dependencies)
        result += 'pb_byte_t %s_default[] = {0x00};\n' % self.msg.name
        result += self.msg.fields_definition(dependencies)
        result += 'const pb_extension_type_t %s = {\n' % self.fullname
        result += '    NULL,\n'
        result += '    NULL,\n'
        result += '    &%s_msg\n' % self.msg.name
        result += '};\n'
        return result
1010
1011
1012# ---------------------------------------------------------------------------
1013#                   Generation of oneofs (unions)
1014# ---------------------------------------------------------------------------
1015
class OneOf(Field):
    '''A oneof group, generated as a tagged union inside the parent struct.
    Presents the Field interface so it can live in Message.fields.'''

    def __init__(self, struct_name, oneof_desc, oneof_options):
        self.struct_name = struct_name
        self.name = oneof_desc.name
        self.ctype = 'union'
        self.pbtype = 'oneof'
        self.fields = []
        self.allocation = 'ONEOF'
        self.default = None
        self.rules = 'ONEOF'
        self.anonymous = oneof_options.anonymous_oneof
        self.sort_by_tag = oneof_options.sort_by_tag
        # Set when any member is a message with callback (MSG_W_CB):
        # the union then gets a shared pb_callback_t in front of it.
        self.has_msg_cb = False

    def add_field(self, field):
        '''Add a member field to this oneof; rewrites its rules to ONEOF.'''
        field.union_name = self.name
        field.rules = 'ONEOF'
        field.anonymous = self.anonymous
        self.fields.append(field)

        if self.sort_by_tag:
            self.fields.sort()

        if field.pbtype == 'MSG_W_CB':
            self.has_msg_cb = True

        # Sort by the lowest tag number inside union
        self.tag = min([f.tag for f in self.fields])

    def __str__(self):
        '''Generate the which_<name> discriminator and union declaration.'''
        result = ''
        if self.fields:
            if self.has_msg_cb:
                result += '    pb_callback_t cb_' + self.name + ';\n'

            result += '    pb_size_t which_' + self.name + ";\n"
            result += '    union {\n'
            for f in self.fields:
                result += '    ' + str(f).replace('\n', '\n    ') + '\n'
            if self.anonymous:
                result += '    };'
            else:
                result += '    } ' + self.name + ';'
        return result

    def types(self):
        return ''.join([f.types() for f in self.fields])

    def get_dependencies(self):
        deps = []
        for f in self.fields:
            deps += f.get_dependencies()
        return deps

    def get_initializer(self, null_init):
        # Initialize discriminator to 0 and the union via its first member.
        if self.has_msg_cb:
            return '{{NULL}, NULL}, 0, {' + self.fields[0].get_initializer(null_init) + '}'
        else:
            return '0, {' + self.fields[0].get_initializer(null_init) + '}'

    def tags(self):
        return ''.join([f.tags() for f in self.fields])

    def data_size(self, dependencies):
        # A union is as large as its largest member.
        return max(f.data_size(dependencies) for f in self.fields)

    def encoded_size(self, dependencies):
        '''Returns the size of the largest oneof field.'''
        largest = 0
        dynamic_sizes = {}
        for f in self.fields:
            size = EncodedSize(f.encoded_size(dependencies))
            # NOTE(review): 'size' itself can never be None after the
            # EncodedSize() wrap; the effective guard is size.value is None.
            # Confirm EncodedSize(None) semantics before simplifying.
            if size is None or size.value is None:
                return None
            elif size.symbols:
                # Size depends on a #define from another file.
                dynamic_sizes[f.tag] = size
            elif size.value > largest:
                largest = size.value

        if not dynamic_sizes:
            # Simple case, all sizes were known at generator time
            return EncodedSize(largest)

        if largest > 0:
            # Some sizes were known, some were not
            dynamic_sizes[0] = EncodedSize(largest)

        # Couldn't find size for submessage at generation time,
        # have to rely on macro resolution at compile time.
        if len(dynamic_sizes) == 1:
            # Only one symbol was needed
            return list(dynamic_sizes.values())[0]
        else:
            # Use sizeof(union{}) construct to find the maximum size of
            # submessages.
            union_name = "%s_%s_size_union" % (self.struct_name, self.name)
            union_def = 'union %s {%s};\n' % (union_name, ' '.join('char f%d[%s];' % (k, s) for k,s in dynamic_sizes.items()))
            required_defs = list(itertools.chain.from_iterable(s.required_defines for k,s in dynamic_sizes.items()))
            return EncodedSize(0, ['sizeof(union %s)' % union_name], [union_def], required_defs)

    def has_callbacks(self):
        return bool([f for f in self.fields if f.has_callbacks()])

    def requires_custom_field_callback(self):
        return bool([f for f in self.fields if f.requires_custom_field_callback()])
1121
1122# ---------------------------------------------------------------------------
1123#                   Generation of messages (structures)
1124# ---------------------------------------------------------------------------
1125
1126
class Message(ProtoElement):
    '''Represents one protobuf message: generates its C struct typedef,
    X-macro field list, descriptor binding and serialized default values.'''

    def __init__(self, names, desc, message_options, index, comments):
        '''names: Names for the message type.
        desc: DescriptorProto, or None to create an empty message without
              loading fields (used e.g. by ExtensionField's carrier message).
        message_options: resolved nanopb options for this message.
        index, comments: passed to ProtoElement for source-comment lookup.
        '''
        super(Message, self).__init__(MESSAGE_PATH, index, comments)
        self.name = names
        self.fields = []
        self.oneofs = {}
        self.desc = desc
        self.math_include_required = False
        self.packed = message_options.packed_struct
        self.descriptorsize = message_options.descriptorsize

        if message_options.msgid:
            self.msgid = message_options.msgid

        if desc is not None:
            self.load_fields(desc, message_options)

        self.callback_function = message_options.callback_function
        if not message_options.HasField('callback_function'):
            # Automatically assign a per-message callback if any field has
            # a special callback_datatype.
            for field in self.fields:
                if field.requires_custom_field_callback():
                    self.callback_function = "%s_callback" % self.name
                    break

    def load_fields(self, desc, message_options):
        '''Load field list from DescriptorProto'''

        # Indexes of oneof declarations whose members should be generated
        # as plain fields instead of a union.
        no_unions = []

        if hasattr(desc, 'oneof_decl'):
            for i, f in enumerate(desc.oneof_decl):
                oneof_options = get_nanopb_suboptions(desc, message_options, self.name + f.name)
                if oneof_options.no_unions:
                    no_unions.append(i) # No union, but add fields normally
                elif oneof_options.type == nanopb_pb2.FT_IGNORE:
                    pass # No union and skip fields also
                else:
                    oneof = OneOf(self.name, f, oneof_options)
                    self.oneofs[i] = oneof
        else:
            sys.stderr.write('Note: This Python protobuf library has no OneOf support\n')

        for f in desc.field:
            field_options = get_nanopb_suboptions(f, message_options, self.name + f.name)
            if field_options.type == nanopb_pb2.FT_IGNORE:
                continue

            # A field can require a wider descriptor than the message default.
            if field_options.descriptorsize > self.descriptorsize:
                self.descriptorsize = field_options.descriptorsize

            field = Field(self.name, f, field_options)
            if hasattr(f, 'oneof_index') and f.HasField('oneof_index'):
                # proto3 optional fields are represented as single-member
                # oneofs; generate them as regular optional fields.
                if hasattr(f, 'proto3_optional') and f.proto3_optional:
                    no_unions.append(f.oneof_index)

                if f.oneof_index in no_unions:
                    self.fields.append(field)
                elif f.oneof_index in self.oneofs:
                    self.oneofs[f.oneof_index].add_field(field)

                    # Add the union itself to the struct on its first member.
                    if self.oneofs[f.oneof_index] not in self.fields:
                        self.fields.append(self.oneofs[f.oneof_index])
            else:
                self.fields.append(field)

            if field.math_include_required:
                self.math_include_required = True

        if len(desc.extension_range) > 0:
            field_options = get_nanopb_suboptions(desc, message_options, self.name + 'extensions')
            range_start = min([r.start for r in desc.extension_range])
            if field_options.type != nanopb_pb2.FT_IGNORE:
                self.fields.append(ExtensionRange(self.name, range_start, field_options))

        if message_options.sort_by_tag:
            self.fields.sort()

    def get_dependencies(self):
        '''Get list of type names that this structure refers to.'''
        deps = []
        for f in self.fields:
            deps += f.get_dependencies()
        return deps

    def __str__(self):
        '''Generate the C struct typedef for this message.'''
        message_path = self.element_path()
        leading_comment, trailing_comment = self.get_comments(message_path, leading_indent=False)

        result = ''
        if leading_comment:
            result = '%s\n' % leading_comment

        result += 'typedef struct _%s { %s\n' % (self.name, trailing_comment)

        if not self.fields:
            # Empty structs are not allowed in C standard.
            # Therefore add a dummy field if an empty message occurs.
            result += '    char dummy_field;'

        msg_fields = []
        for index, field in enumerate(self.fields):
            member_path = self.member_path(index)
            leading_comment, trailing_comment = self.get_comments(member_path)

            if leading_comment:
                msg_fields.append(leading_comment)

            msg_fields.append("%s %s" % (str(field), trailing_comment))

        result += '\n'.join(msg_fields)

        if Globals.protoc_insertion_points:
            result += '\n/* @@protoc_insertion_point(struct:%s) */' % self.name

        result += '\n}'

        if self.packed:
            result += ' pb_packed'

        result += ' %s;' % self.name

        if self.packed:
            result = 'PB_PACKED_STRUCT_START\n' + result
            result += '\nPB_PACKED_STRUCT_END'

        return result + '\n'

    def types(self):
        '''Concatenated typedefs required by the fields (e.g. byte arrays).'''
        return ''.join([f.types() for f in self.fields])

    def get_initializer(self, null_init):
        '''Return the initializer expression for the whole struct.'''
        if not self.fields:
            # Matches the dummy_field emitted for empty structs.
            return '{0}'

        parts = []
        for field in self.fields:
            parts.append(field.get_initializer(null_init))
        return '{' + ', '.join(parts) + '}'

    def count_required_fields(self):
        '''Returns number of required fields inside this message'''
        count = 0
        for f in self.fields:
            if not isinstance(f, OneOf):
                if f.rules == 'REQUIRED':
                    count += 1
        return count

    def all_fields(self):
        '''Iterate over all fields in this message, including nested OneOfs.'''
        for f in self.fields:
            if isinstance(f, OneOf):
                for f2 in f.fields:
                    yield f2
            else:
                yield f


    def field_for_tag(self, tag):
        '''Given a tag number, return the Field instance.'''
        for field in self.all_fields():
            if field.tag == tag:
                return field
        return None

    def count_all_fields(self):
        '''Count the total number of fields in this message.'''
        count = 0
        for f in self.fields:
            if isinstance(f, OneOf):
                count += len(f.fields)
            else:
                count += 1
        return count

    def fields_declaration(self, dependencies):
        '''Return X-macro declaration of all fields in this message.'''
        # Pick macro parameter names that don't collide with field names.
        Field.macro_x_param = 'X'
        Field.macro_a_param = 'a'
        while any(field.name == Field.macro_x_param for field in self.all_fields()):
            Field.macro_x_param += '_'
        while any(field.name == Field.macro_a_param for field in self.all_fields()):
            Field.macro_a_param += '_'

        # Field descriptor array must be sorted by tag number, pb_common.c relies on it.
        sorted_fields = list(self.all_fields())
        sorted_fields.sort(key = lambda x: x.tag)

        result = '#define %s_FIELDLIST(%s, %s) \\\n' % (self.name,
                                                        Field.macro_x_param,
                                                        Field.macro_a_param)
        result += ' \\\n'.join(x.fieldlist() for x in sorted_fields)
        result += '\n'

        has_callbacks = bool([f for f in self.fields if f.has_callbacks()])
        if has_callbacks:
            if self.callback_function != 'pb_default_field_callback':
                result += "extern bool %s(pb_istream_t *istream, pb_ostream_t *ostream, const pb_field_t *field);\n" % self.callback_function
            result += "#define %s_CALLBACK %s\n" % (self.name, self.callback_function)
        else:
            result += "#define %s_CALLBACK NULL\n" % self.name

        # Serialized defaults blob, hex-escaped and terminated with \x00.
        defval = self.default_value(dependencies)
        if defval:
            hexcoded = ''.join("\\x%02x" % ord(defval[i:i+1]) for i in range(len(defval)))
            result += '#define %s_DEFAULT (const pb_byte_t*)"%s\\x00"\n' % (self.name, hexcoded)
        else:
            result += '#define %s_DEFAULT NULL\n' % self.name

        for field in sorted_fields:
            if field.pbtype in ['MESSAGE', 'MSG_W_CB']:
                if field.rules == 'ONEOF':
                    result += "#define %s_%s_%s_MSGTYPE %s\n" % (self.name, field.union_name, field.name, field.ctype)
                else:
                    result += "#define %s_%s_MSGTYPE %s\n" % (self.name, field.name, field.ctype)

        return result

    def fields_declaration_cpp_lookup(self):
        '''Return the C++ MessageDescriptor template specialization.'''
        result = 'template <>\n'
        result += 'struct MessageDescriptor<%s> {\n' % (self.name)
        result += '    static PB_INLINE_CONSTEXPR const pb_size_t fields_array_length = %d;\n' % (self.count_all_fields())
        result += '    static inline const pb_msgdesc_t* fields() {\n'
        result += '        return &%s_msg;\n' % (self.name)
        result += '    }\n'
        result += '};'
        return result

    def fields_definition(self, dependencies):
        '''Return the field descriptor definition that goes in .pb.c file.'''
        width = self.required_descriptor_width(dependencies)
        if width == 1:
          width = 'AUTO'

        result = 'PB_BIND(%s, %s, %s)\n' % (self.name, self.name, width)
        return result

    def required_descriptor_width(self, dependencies):
        '''Estimate how many words are necessary for each field descriptor.'''
        if self.descriptorsize != nanopb_pb2.DS_AUTO:
            return int(self.descriptorsize)

        if not self.fields:
          return 1

        max_tag = max(field.tag for field in self.all_fields())
        max_offset = self.data_size(dependencies)
        max_arraysize = max((field.max_count or 0) for field in self.all_fields())
        max_datasize = max(field.data_size(dependencies) for field in self.all_fields())

        if max_arraysize > 0xFFFF:
            return 8
        elif (max_tag > 0x3FF or max_offset > 0xFFFF or
              max_arraysize > 0x0FFF or max_datasize > 0x0FFF):
            return 4
        elif max_tag > 0x3F or max_offset > 0xFF:
            return 2
        else:
            # NOTE: Macro logic in pb.h ensures that width 1 will
            # be raised to 2 automatically for string/submsg fields
            # and repeated fields. Thus only tag and offset need to
            # be checked.
            return 1

    def data_size(self, dependencies):
        '''Return approximate sizeof(struct) in the compiled code.'''
        return sum(f.data_size(dependencies) for f in self.fields)

    def encoded_size(self, dependencies):
        '''Return the maximum size that this message can take when encoded.
        If the size cannot be determined, returns None.
        '''
        size = EncodedSize(0)
        for field in self.fields:
            fsize = field.encoded_size(dependencies)
            if fsize is None:
                return None
            size += fsize

        return size

    def default_value(self, dependencies):
        '''Generate serialized protobuf message that contains the
        default values for optional fields.'''

        if not self.desc:
            return b''

        if self.desc.options.map_entry:
            return b''

        # Work on a copy of the descriptor: fields that don't need a
        # default are stripped, then the remainder is serialized.
        optional_only = copy.deepcopy(self.desc)

        # Remove fields without default values
        # The iteration is done in reverse order to avoid remove() messing up iteration.
        for field in reversed(list(optional_only.field)):
            field.ClearField(str('extendee'))
            parsed_field = self.field_for_tag(field.number)
            if parsed_field is None or parsed_field.allocation != 'STATIC':
                optional_only.field.remove(field)
            elif (field.label == FieldD.LABEL_REPEATED or
                  field.type == FieldD.TYPE_MESSAGE):
                optional_only.field.remove(field)
            elif hasattr(field, 'oneof_index') and field.HasField('oneof_index'):
                optional_only.field.remove(field)
            elif field.type == FieldD.TYPE_ENUM:
                # The partial descriptor doesn't include the enum type
                # so we fake it with int64.
                enumname = names_from_type_name(field.type_name)
                try:
                    enumtype = dependencies[str(enumname)]
                except KeyError:
                    raise Exception("Could not find enum type %s while generating default values for %s.\n" % (enumname, self.name)
                                    + "Try passing all source files to generator at once, or use -I option.")

                if field.HasField('default_value'):
                    defvals = [v for n,v in enumtype.values if n.parts[-1] == field.default_value]
                else:
                    # If no default is specified, the default is the first value.
                    defvals = [v for n,v in enumtype.values]
                if defvals and defvals[0] != 0:
                    field.type = FieldD.TYPE_INT64
                    field.default_value = str(defvals[0])
                    field.ClearField(str('type_name'))
                else:
                    optional_only.field.remove(field)
            elif not field.HasField('default_value'):
                optional_only.field.remove(field)

        if len(optional_only.field) == 0:
            return b''

        # Build a temporary Python message class from the stripped
        # descriptor and fill in the default values.
        optional_only.ClearField(str('oneof_decl'))
        optional_only.ClearField(str('nested_type'))
        optional_only.ClearField(str('extension'))
        optional_only.ClearField(str('enum_type'))
        desc = google.protobuf.descriptor.MakeDescriptor(optional_only)
        msg = reflection.MakeClass(desc)()

        for field in optional_only.field:
            if field.type == FieldD.TYPE_STRING:
                setattr(msg, field.name, field.default_value)
            elif field.type == FieldD.TYPE_BYTES:
                setattr(msg, field.name, codecs.escape_decode(field.default_value)[0])
            elif field.type in [FieldD.TYPE_FLOAT, FieldD.TYPE_DOUBLE]:
                setattr(msg, field.name, float(field.default_value))
            elif field.type == FieldD.TYPE_BOOL:
                setattr(msg, field.name, field.default_value == 'true')
            else:
                setattr(msg, field.name, int(field.default_value))

        return msg.SerializeToString()
1481
1482
1483# ---------------------------------------------------------------------------
1484#                    Processing of entire .proto files
1485# ---------------------------------------------------------------------------
1486
def iterate_messages(desc, flatten = False, names = Names()):
    '''Recursively find all messages. For each, yield name, DescriptorProto.'''
    # A FileDescriptorProto keeps its messages in 'message_type', while a
    # DescriptorProto keeps nested messages in 'nested_type'.
    if hasattr(desc, 'message_type'):
        children = desc.message_type
    else:
        children = desc.nested_type

    for child in children:
        child_names = names + child.name
        # In flatten mode the enclosing path is dropped and only the bare
        # message name is reported; otherwise the full dotted path is used.
        if flatten:
            yield Names(child.name), child
        else:
            yield child_names, child

        # Recurse into nested message definitions.
        for pair in iterate_messages(child, flatten, child_names):
            yield pair
1503
def iterate_extensions(desc, flatten = False, names = Names()):
    '''Recursively find all extensions.
    For each, yield name, FieldDescriptorProto.
    '''
    # Extensions declared directly at this level.
    for ext in desc.extension:
        yield names, ext

    # Extensions declared inside (possibly nested) message definitions.
    for subname, submsg in iterate_messages(desc, flatten, names):
        for ext in submsg.extension:
            yield subname, ext
1514
def toposort2(data):
    '''Topological sort.
    From http://code.activestate.com/recipes/577413-topological-sort/
    This function is under the MIT license.

    data is a dict mapping item -> set of items it depends on.
    Yields items so that every item comes after its dependencies;
    ties are broken by sorted() order. Raises AssertionError on cycles.
    '''
    # Work on a private copy: the original implementation mutated both the
    # caller's dict (data.update) and the sets stored in it (v.discard),
    # which is a surprising side effect on the argument.
    data = dict((item, set(deps)) for item, deps in data.items())
    for k, v in data.items():
        v.discard(k) # Ignore self dependencies
    # Items that appear only as dependencies get an empty dependency set.
    extra_items_in_deps = reduce(set.union, list(data.values()), set()) - set(data.keys())
    data.update(dict((item, set()) for item in extra_items_in_deps))
    while True:
        # Everything with no remaining dependencies can be emitted now.
        ordered = set(item for item, dep in data.items() if not dep)
        if not ordered:
            break
        for item in sorted(ordered):
            yield item
        # Drop the emitted items and remove them from remaining dep sets.
        data = dict((item, (dep - ordered)) for item, dep in data.items()
                if item not in ordered)
    assert not data, "A cyclic dependency exists amongst %r" % data
1533
def sort_dependencies(messages):
    '''Sort a list of Messages based on dependencies.'''
    # Build the dependency graph keyed by message name, and remember which
    # Message object each name belongs to.
    graph = {}
    by_name = {}
    for msg in messages:
        key = str(msg.name)
        graph[key] = set(msg.get_dependencies())
        by_name[key] = msg

    # toposort2 may also yield names of types defined in other files;
    # only yield the messages that actually belong to this list.
    for msgname in toposort2(graph):
        if msgname in by_name:
            yield by_name[msgname]
1545
def make_identifier(headername):
    '''Make #ifndef identifier that contains uppercase A-Z and digits 0-9'''
    # Single pass: keep alphanumeric characters of the uppercased name,
    # replace everything else with '_'. Equivalent to the previous manual
    # string-concatenation loop, but idiomatic and linear.
    return ''.join(c if c.isalnum() else '_' for c in headername.upper())
1555
class ProtoFile:
    '''A single parsed .proto file: holds its messages, enums and
    extensions, and generates the corresponding C header/source text.'''

    def __init__(self, fdesc, file_options):
        '''Takes a FileDescriptorProto and parses it.'''
        self.fdesc = fdesc
        self.file_options = file_options
        # Maps type name (str) -> Enum/Message object; filled in by
        # add_dependency() for this file and for imported files.
        self.dependencies = {}
        # Set below if any parsed message requires <math.h> in the
        # generated header (see generate_header()).
        self.math_include_required = False
        self.parse()
        for message in self.messages:
            if message.math_include_required:
                self.math_include_required = True
                break

        # Some of types used in this file probably come from the file itself.
        # Thus it has implicit dependency on itself.
        self.add_dependency(self)

    def parse(self):
        '''Parse self.fdesc into self.enums, self.messages and
        self.extensions, applying the configured name mangling.'''
        self.enums = []
        self.messages = []
        self.extensions = []

        mangle_names = self.file_options.mangle_names
        flatten = mangle_names == nanopb_pb2.M_FLATTEN
        strip_prefix = None
        replacement_prefix = None
        if mangle_names == nanopb_pb2.M_STRIP_PACKAGE:
            strip_prefix = "." + self.fdesc.package
        elif mangle_names == nanopb_pb2.M_PACKAGE_INITIALS:
            # Replace the package prefix with the initials of its parts
            # (e.g. "foo.bar" -> "fb").
            strip_prefix = "." + self.fdesc.package
            replacement_prefix = ""
            for part in self.fdesc.package.split("."):
                replacement_prefix += part[0]
        elif self.file_options.package:
            strip_prefix = "." + self.fdesc.package
            replacement_prefix = self.file_options.package


        def create_name(names):
            # Build the C-level name for a message/enum according to the
            # selected mangling mode. base_name is assigned further below,
            # before this closure is first called.
            if mangle_names in (nanopb_pb2.M_NONE, nanopb_pb2.M_PACKAGE_INITIALS):
                return base_name + names
            if mangle_names == nanopb_pb2.M_STRIP_PACKAGE:
                return Names(names)
            single_name = names
            if isinstance(names, Names):
                single_name = names.parts[-1]
            return Names(single_name)

        def mangle_field_typename(typename):
            # Apply the same mangling to type references inside fields
            # (typename is the dotted protobuf type name, e.g. ".pkg.Msg").
            if mangle_names == nanopb_pb2.M_FLATTEN:
                return "." + typename.split(".")[-1]
            if strip_prefix is not None and typename.startswith(strip_prefix):
                if replacement_prefix is not None:
                    return "." + replacement_prefix + typename[len(strip_prefix):]
                else:
                    return typename[len(strip_prefix):]
            if self.file_options.package:
                return "." + replacement_prefix + typename
            return typename

        if replacement_prefix is not None:
            base_name = Names(replacement_prefix.split('.'))
        elif self.fdesc.package:
            base_name = Names(self.fdesc.package.split('.'))
        else:
            base_name = Names()

        # process source code comment locations
        # ignores any locations that do not contain any comment information
        self.comment_locations = {
            str(list(location.path)): location
            for location in self.fdesc.source_code_info.location
            if location.leading_comments or location.leading_detached_comments or location.trailing_comments
        }

        for index, enum in enumerate(self.fdesc.enum_type):
            name = create_name(enum.name)
            enum_options = get_nanopb_suboptions(enum, self.file_options, name)
            self.enums.append(Enum(name, enum, enum_options, index, self.comment_locations))

        for index, (names, message) in enumerate(iterate_messages(self.fdesc, flatten)):
            name = create_name(names)
            message_options = get_nanopb_suboptions(message, self.file_options, name)

            if message_options.skip_message:
                continue

            # Deep copy so that type name mangling below does not modify
            # the original descriptor in self.fdesc.
            message = copy.deepcopy(message)
            for field in message.field:
                if field.type in (FieldD.TYPE_MESSAGE, FieldD.TYPE_ENUM):
                    field.type_name = mangle_field_typename(field.type_name)

            self.messages.append(Message(name, message, message_options, index, self.comment_locations))
            # NOTE(review): this inner loop reuses the outer 'index'
            # variable; the enum index passed here is relative to the
            # enclosing message, which appears intentional.
            for index, enum in enumerate(message.enum_type):
                name = create_name(names + enum.name)
                enum_options = get_nanopb_suboptions(enum, message_options, name)
                self.enums.append(Enum(name, enum, enum_options, index, self.comment_locations))

        for names, extension in iterate_extensions(self.fdesc, flatten):
            name = create_name(names + extension.name)
            field_options = get_nanopb_suboptions(extension, self.file_options, name)

            # Deep copy before mangling, same as for messages above.
            extension = copy.deepcopy(extension)
            if extension.type in (FieldD.TYPE_MESSAGE, FieldD.TYPE_ENUM):
                extension.type_name = mangle_field_typename(extension.type_name)

            if field_options.type != nanopb_pb2.FT_IGNORE:
                self.extensions.append(ExtensionField(name, extension, field_options))

    def add_dependency(self, other):
        '''Register the types of another ProtoFile (possibly self) so that
        field references in this file can be resolved against them.'''
        for enum in other.enums:
            self.dependencies[str(enum.names)] = enum
            enum.protofile = other

        for msg in other.messages:
            self.dependencies[str(msg.name)] = msg
            msg.protofile = other

        # Fix field default values where enum short names are used.
        for enum in other.enums:
            if not enum.options.long_names:
                for message in self.messages:
                    for field in message.all_fields():
                        if field.default in enum.value_longnames:
                            idx = enum.value_longnames.index(field.default)
                            field.default = enum.values[idx][0]

        # Fix field data types where enums have negative values.
        for enum in other.enums:
            if not enum.has_negative():
                for message in self.messages:
                    for field in message.all_fields():
                        if field.pbtype == 'ENUM' and field.ctype == enum.names:
                            field.pbtype = 'UENUM'

    def generate_header(self, includes, headername, options):
        '''Generate content for a header file.
        Generates strings, which should be concatenated and stored to file.
        '''

        yield '/* Automatically generated nanopb header */\n'
        if options.notimestamp:
            yield '/* Generated by %s */\n\n' % (nanopb_version)
        else:
            # NOTE(review): relies on a module-level 'time' import that is
            # not visible in this part of the file -- confirm it exists.
            yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime())

        # Include guard symbol is derived from package name + header name.
        if self.fdesc.package:
            symbol = make_identifier(self.fdesc.package + '_' + headername)
        else:
            symbol = make_identifier(headername)
        yield '#ifndef PB_%s_INCLUDED\n' % symbol
        yield '#define PB_%s_INCLUDED\n' % symbol
        if self.math_include_required:
            yield '#include <math.h>\n'
        try:
            yield options.libformat % ('pb.h')
        except TypeError:
            # no %s specified - use whatever was passed in as options.libformat
            yield options.libformat
        yield '\n'

        # Includes requested through the file options.
        for incfile in self.file_options.include:
            # allow including system headers
            if (incfile.startswith('<')):
                yield '#include %s\n' % incfile
            else:
                yield options.genformat % incfile
                yield '\n'

        # Includes for the .proto files this file imports.
        for incfile in includes:
            noext = os.path.splitext(incfile)[0]
            yield options.genformat % (noext + options.extension + options.header_extension)
            yield '\n'

        if Globals.protoc_insertion_points:
            yield '/* @@protoc_insertion_point(includes) */\n'

        yield '\n'

        # Guard against headers generated by an incompatible nanopb version.
        yield '#if PB_PROTO_HEADER_VERSION != 40\n'
        yield '#error Regenerate this file with the current version of nanopb generator.\n'
        yield '#endif\n'
        yield '\n'

        if self.enums:
            yield '/* Enum definitions */\n'
            for enum in self.enums:
                yield str(enum) + '\n\n'

        if self.messages:
            yield '/* Struct definitions */\n'
            # Dependency order so that struct members are declared before use.
            for msg in sort_dependencies(self.messages):
                yield msg.types()
                yield str(msg) + '\n'
            yield '\n'

        if self.extensions:
            yield '/* Extensions */\n'
            for extension in self.extensions:
                yield extension.extension_decl()
            yield '\n'

        if self.enums:
                yield '/* Helper constants for enums */\n'
                for enum in self.enums:
                    yield enum.auxiliary_defines() + '\n'
                yield '\n'

        yield '#ifdef __cplusplus\n'
        yield 'extern "C" {\n'
        yield '#endif\n\n'

        if self.messages:
            yield '/* Initializer values for message structs */\n'
            for msg in self.messages:
                identifier = '%s_init_default' % msg.name
                yield '#define %-40s %s\n' % (identifier, msg.get_initializer(False))
            for msg in self.messages:
                identifier = '%s_init_zero' % msg.name
                yield '#define %-40s %s\n' % (identifier, msg.get_initializer(True))
            yield '\n'

            yield '/* Field tags (for use in manual encoding/decoding) */\n'
            for msg in sort_dependencies(self.messages):
                for field in msg.fields:
                    yield field.tags()
            for extension in self.extensions:
                yield extension.tags()
            yield '\n'

            yield '/* Struct field encoding specification for nanopb */\n'
            for msg in self.messages:
                yield msg.fields_declaration(self.dependencies) + '\n'
            for msg in self.messages:
                yield 'extern const pb_msgdesc_t %s_msg;\n' % msg.name
            yield '\n'

            yield '/* Defines for backwards compatibility with code written before nanopb-0.4.0 */\n'
            for msg in self.messages:
              yield '#define %s_fields &%s_msg\n' % (msg.name, msg.name)
            yield '\n'

            yield '/* Maximum encoded size of messages (where known) */\n'
            messagesizes = []
            for msg in self.messages:
                identifier = '%s_size' % msg.name
                messagesizes.append((identifier, msg.encoded_size(self.dependencies)))

            # If we require a symbol from another file, put a preprocessor if statement
            # around it to prevent compilation errors if the symbol is not actually available.
            local_defines = [identifier for identifier, msize in messagesizes if msize is not None]
            guards = {}
            for identifier, msize in messagesizes:
                if msize is not None:
                    cpp_guard = msize.get_cpp_guard(local_defines)
                    if cpp_guard not in guards:
                        guards[cpp_guard] = set()
                    for decl in msize.get_declarations().splitlines():
                        guards[cpp_guard].add(decl)
                    guards[cpp_guard].add('#define %-40s %s' % (identifier, msize))
                else:
                    yield '/* %s depends on runtime parameters */\n' % identifier
            for guard, values in guards.items():
                if guard:
                    yield guard
                for v in sorted(values):
                    yield v
                    yield '\n'
                if guard:
                    yield '#endif\n'
            yield '\n'

            # Message ID macros, only emitted when some message sets the
            # "msgid" option.
            if [msg for msg in self.messages if hasattr(msg,'msgid')]:
              yield '/* Message IDs (where set with "msgid" option) */\n'
              for msg in self.messages:
                  if hasattr(msg,'msgid'):
                      yield '#define PB_MSG_%d %s\n' % (msg.msgid, msg.name)
              yield '\n'

              symbol = make_identifier(headername.split('.')[0])
              yield '#define %s_MESSAGES \\\n' % symbol

              for msg in self.messages:
                  # "-1" stands for "size depends on runtime parameters".
                  m = "-1"
                  msize = msg.encoded_size(self.dependencies)
                  if msize is not None:
                      m = msize
                  if hasattr(msg,'msgid'):
                      yield '\tPB_MSG(%d,%s,%s) \\\n' % (msg.msgid, m, msg.name)
              yield '\n'

              for msg in self.messages:
                  if hasattr(msg,'msgid'):
                      yield '#define %s_msgid %d\n' % (msg.name, msg.msgid)
              yield '\n'

        yield '#ifdef __cplusplus\n'
        yield '} /* extern "C" */\n'
        yield '#endif\n'

        if options.cpp_descriptors:
            yield '\n'
            yield '#ifdef __cplusplus\n'
            yield '/* Message descriptors for nanopb */\n'
            yield 'namespace nanopb {\n'
            for msg in self.messages:
                yield msg.fields_declaration_cpp_lookup() + '\n'
            yield '}  // namespace nanopb\n'
            yield '\n'
            yield '#endif  /* __cplusplus */\n'
            yield '\n'

        if Globals.protoc_insertion_points:
            yield '/* @@protoc_insertion_point(eof) */\n'

        # End of header
        yield '\n#endif\n'

    def generate_source(self, headername, options):
        '''Generate content for a source file.'''

        yield '/* Automatically generated nanopb constant definitions */\n'
        if options.notimestamp:
            yield '/* Generated by %s */\n\n' % (nanopb_version)
        else:
            # NOTE(review): see generate_header -- depends on 'time' import.
            yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime())
        yield options.genformat % (headername)
        yield '\n'

        if Globals.protoc_insertion_points:
            yield '/* @@protoc_insertion_point(includes) */\n'

        yield '#if PB_PROTO_HEADER_VERSION != 40\n'
        yield '#error Regenerate this file with the current version of nanopb generator.\n'
        yield '#endif\n'
        yield '\n'

        for msg in self.messages:
            yield msg.fields_definition(self.dependencies) + '\n\n'

        for ext in self.extensions:
            yield ext.extension_def(self.dependencies) + '\n'

        for enum in self.enums:
            yield enum.enum_to_string_definition() + '\n'

        # Add checks for numeric limits
        if self.messages:
            largest_msg = max(self.messages, key = lambda m: m.count_required_fields())
            largest_count = largest_msg.count_required_fields()
            if largest_count > 64:
                yield '\n/* Check that missing required fields will be properly detected */\n'
                yield '#if PB_MAX_REQUIRED_FIELDS < %d\n' % largest_count
                yield '#error Properly detecting missing required fields in %s requires \\\n' % largest_msg.name
                yield '       setting PB_MAX_REQUIRED_FIELDS to %d or more.\n' % largest_count
                yield '#endif\n'

        # Add check for sizeof(double)
        has_double = False
        for msg in self.messages:
            for field in msg.all_fields():
                if field.ctype == 'double':
                    has_double = True

        if has_double:
            yield '\n'
            yield '#ifndef PB_CONVERT_DOUBLE_FLOAT\n'
            yield '/* On some platforms (such as AVR), double is really float.\n'
            yield ' * To be able to encode/decode double on these platforms, you need.\n'
            yield ' * to define PB_CONVERT_DOUBLE_FLOAT in pb.h or compiler command line.\n'
            yield ' */\n'
            yield 'PB_STATIC_ASSERT(sizeof(double) == 8, DOUBLE_MUST_BE_8_BYTES)\n'
            yield '#endif\n'

        yield '\n'

        if Globals.protoc_insertion_points:
            yield '/* @@protoc_insertion_point(eof) */\n'
1934
1935# ---------------------------------------------------------------------------
1936#                    Options parsing for the .proto files
1937# ---------------------------------------------------------------------------
1938
1939from fnmatch import fnmatchcase
1940
def read_options_file(infile):
    '''Parse a separate options file to list:
        [(namemask, options), ...]

    Each non-empty line must contain a field name mask, whitespace, and a
    text-format NanoPBOptions message. Any parse error aborts the program.
    '''
    results = []
    data = infile.read()
    # Strip C-style, C++-style and shell-style comments before parsing.
    # NOTE(review): with re.MULTILINE, '.' still does not match newlines,
    # so /* */ comments spanning multiple lines are NOT removed -- confirm
    # whether multi-line block comments should be supported (re.DOTALL).
    data = re.sub(r'/\*.*?\*/', '', data, flags = re.MULTILINE)
    data = re.sub(r'//.*?$', '', data, flags = re.MULTILINE)
    data = re.sub(r'#.*?$', '', data, flags = re.MULTILINE)
    for i, line in enumerate(data.split('\n')):
        line = line.strip()
        if not line:
            continue

        parts = line.split(None, 1)

        if len(parts) < 2:
            # Fix: the old message said "Skipping line" although the
            # generator actually aborts here with an error status.
            sys.stderr.write("%s:%d: " % (infile.name, i + 1) +
                             "Option lines should have space between field name and options. " +
                             "Invalid line: '%s'\n" % line)
            sys.exit(1)

        opts = nanopb_pb2.NanoPBOptions()

        try:
            text_format.Merge(parts[1], opts)
        except Exception as e:
            sys.stderr.write("%s:%d: " % (infile.name, i + 1) +
                             "Unparseable option line: '%s'. " % line +
                             "Error: %s\n" % str(e))
            sys.exit(1)
        results.append((parts[0], opts))

    return results
1975
def get_nanopb_suboptions(subdesc, options, name):
    '''Get copy of options, and merge information from subdesc.

    subdesc: descriptor (file/message/field/enum) whose own options and
             matching .options-file entries are merged on top of 'options'.
    Returns a new NanoPBOptions instance; 'options' is not modified.
    '''
    new_options = nanopb_pb2.NanoPBOptions()
    new_options.CopyFrom(options)

    if hasattr(subdesc, 'syntax') and subdesc.syntax == "proto3":
        new_options.proto3 = True

    # Handle options defined in a separate file
    dotname = '.'.join(name.parts)
    # Fix: the loop variable used to be called 'options', shadowing the
    # function parameter of the same name; renamed for clarity.
    for namemask, mask_options in Globals.separate_options:
        if fnmatchcase(dotname, namemask):
            Globals.matched_namemasks.add(namemask)
            new_options.MergeFrom(mask_options)

    # Handle options defined in .proto
    if isinstance(subdesc.options, descriptor.FieldOptions):
        ext_type = nanopb_pb2.nanopb
    elif isinstance(subdesc.options, descriptor.FileOptions):
        ext_type = nanopb_pb2.nanopb_fileopt
    elif isinstance(subdesc.options, descriptor.MessageOptions):
        ext_type = nanopb_pb2.nanopb_msgopt
    elif isinstance(subdesc.options, descriptor.EnumOptions):
        ext_type = nanopb_pb2.nanopb_enumopt
    else:
        raise Exception("Unknown options type")

    if subdesc.options.HasExtension(ext_type):
        ext = subdesc.options.Extensions[ext_type]
        new_options.MergeFrom(ext)

    if Globals.verbose_options:
        sys.stderr.write("Options for " + dotname + ": ")
        sys.stderr.write(text_format.MessageToString(new_options) + "\n")

    return new_options
2012
2013
2014# ---------------------------------------------------------------------------
2015#                         Command line interface
2016# ---------------------------------------------------------------------------
2017
2018import sys
2019import os.path
2020from optparse import OptionParser
2021
# Command line option definitions, shared by the stand-alone generator
# entry point. Option help strings are user-visible (--help output).
optparser = OptionParser(
    usage = "Usage: nanopb_generator.py [options] file.pb ...",
    epilog = "Compile file.pb from file.proto by: 'protoc -ofile.pb file.proto'. " +
             "Output will be written to file.pb.h and file.pb.c.")
optparser.add_option("--version", dest="version", action="store_true",
    help="Show version info and exit")
# Output file naming options.
optparser.add_option("-x", dest="exclude", metavar="FILE", action="append", default=[],
    help="Exclude file from generated #include list.")
optparser.add_option("-e", "--extension", dest="extension", metavar="EXTENSION", default=".pb",
    help="Set extension to use instead of '.pb' for generated files. [default: %default]")
optparser.add_option("-H", "--header-extension", dest="header_extension", metavar="EXTENSION", default=".h",
    help="Set extension to use for generated header files. [default: %default]")
optparser.add_option("-S", "--source-extension", dest="source_extension", metavar="EXTENSION", default=".c",
    help="Set extension to use for generated source files. [default: %default]")
# Location of the separate .options files.
optparser.add_option("-f", "--options-file", dest="options_file", metavar="FILE", default="%s.options",
    help="Set name of a separate generator options file.")
optparser.add_option("-I", "--options-path", dest="options_path", metavar="DIR",
    action="append", default = [],
    help="Search for .options files additionally in this path")
optparser.add_option("--error-on-unmatched", dest="error_on_unmatched", action="store_true", default=False,
                     help ="Stop generation if there are unmatched fields in options file")
optparser.add_option("--no-error-on-unmatched", dest="error_on_unmatched", action="store_false", default=False,
                     help ="Continue generation if there are unmatched fields in options file (default)")
optparser.add_option("-D", "--output-dir", dest="output_dir",
                     metavar="OUTPUTDIR", default=None,
                     help="Output directory of .pb.h and .pb.c files")
# #include formatting in the generated files.
optparser.add_option("-Q", "--generated-include-format", dest="genformat",
    metavar="FORMAT", default='#include "%s"',
    help="Set format string to use for including other .pb.h files. [default: %default]")
optparser.add_option("-L", "--library-include-format", dest="libformat",
    metavar="FORMAT", default='#include <%s>',
    help="Set format string to use for including the nanopb pb.h header. [default: %default]")
optparser.add_option("--strip-path", dest="strip_path", action="store_true", default=False,
    help="Strip directory path from #included .pb.h file name")
optparser.add_option("--no-strip-path", dest="strip_path", action="store_false",
    help="Opposite of --strip-path (default since 0.4.0)")
optparser.add_option("--cpp-descriptors", action="store_true",
    help="Generate C++ descriptors to lookup by type (e.g. pb_field_t for a message)")
# Timestamp and verbosity controls.
optparser.add_option("-T", "--no-timestamp", dest="notimestamp", action="store_true", default=True,
    help="Don't add timestamp to .pb.h and .pb.c preambles (default since 0.4.0)")
optparser.add_option("-t", "--timestamp", dest="notimestamp", action="store_false", default=True,
    help="Add timestamp to .pb.h and .pb.c preambles")
optparser.add_option("-q", "--quiet", dest="quiet", action="store_true", default=False,
    help="Don't print anything except errors.")
optparser.add_option("-v", "--verbose", dest="verbose", action="store_true", default=False,
    help="Print more information.")
optparser.add_option("-s", dest="settings", metavar="OPTION:VALUE", action="append", default=[],
    help="Set generator option (max_size, max_count etc.).")
optparser.add_option("--protoc-insertion-points", dest="protoc_insertion_points", action="store_true", default=False,
                     help="Include insertion point comments in output for use by custom protoc plugins")
2072
def parse_file(filename, fdesc, options):
    '''Parse a single file. Returns a ProtoFile instance.

    filename: path to the .pb descriptor file (used also to locate the
              matching .options file).
    fdesc:    pre-loaded FileDescriptorProto, or None to read from file.
    options:  command line options from OptionParser.
    '''
    toplevel_options = nanopb_pb2.NanoPBOptions()
    for s in options.settings:
        text_format.Merge(s, toplevel_options)

    if not fdesc:
        # Fix: use a context manager so the descriptor file handle is
        # closed promptly instead of being leaked.
        with open(filename, 'rb') as f:
            data = f.read()
        fdesc = descriptor.FileDescriptorSet.FromString(data).file[0]

    # Check if there is a separate .options file
    had_abspath = False
    try:
        optfilename = options.options_file % os.path.splitext(filename)[0]
    except TypeError:
        # No %s specified, use the filename as-is
        optfilename = options.options_file
        had_abspath = True

    paths = ['.'] + options.options_path
    for p in paths:
        if os.path.isfile(os.path.join(p, optfilename)):
            optfilename = os.path.join(p, optfilename)
            if options.verbose:
                sys.stderr.write('Reading options from ' + optfilename + '\n')
            # Fix: close the options file after parsing (was left open).
            with open(optfilename, openmode_unicode) as optfile:
                Globals.separate_options = read_options_file(optfile)
            break
    else:
        # If we are given a full filename and it does not exist, give an error.
        # However, don't give error when we automatically look for .options file
        # with the same name as .proto.
        if options.verbose or had_abspath:
            sys.stderr.write('Options file not found: ' + optfilename + '\n')
        Globals.separate_options = []

    Globals.matched_namemasks = set()
    Globals.protoc_insertion_points = options.protoc_insertion_points

    # Parse the file
    file_options = get_nanopb_suboptions(fdesc, toplevel_options, Names([filename]))
    f = ProtoFile(fdesc, file_options)
    f.optfilename = optfilename

    return f
2117
def process_file(filename, fdesc, options, other_files = None):
    '''Process a single file.
    filename: The full path to the .proto or .pb source file, as string.
    fdesc: The loaded FileDescriptorSet, or None to read from the input file.
    options: Command line options as they come from OptionsParser.
    other_files: Optional dict of {dependency name: ProtoFile} for resolving
                 imports.

    Returns a dict:
        {'headername': Name of header file,
         'headerdata': Data for the .h header file,
         'sourcename': Name of the source code file,
         'sourcedata': Data for the .c source code file
        }
    '''
    # Fix: avoid a mutable default argument; behavior with the default is
    # unchanged (no dependencies resolved).
    if other_files is None:
        other_files = {}

    f = parse_file(filename, fdesc, options)

    # Check the list of dependencies, and if they are available in other_files,
    # add them to be considered for import resolving. Recursively add any files
    # imported by the dependencies.
    deps = list(f.fdesc.dependency)
    visited = set()
    while deps:
        dep = deps.pop(0)
        # Fix: skip already-processed dependencies so diamond-shaped import
        # graphs are handled once and cyclic graphs cannot loop forever.
        if dep in visited:
            continue
        visited.add(dep)
        if dep in other_files:
            f.add_dependency(other_files[dep])
            deps += list(other_files[dep].fdesc.dependency)

    # Decide the file names
    noext = os.path.splitext(filename)[0]
    headername = noext + options.extension + options.header_extension
    sourcename = noext + options.extension + options.source_extension

    if options.strip_path:
        headerbasename = os.path.basename(headername)
    else:
        headerbasename = headername

    # List of .proto files that should not be included in the C header file
    # even if they are mentioned in the source .proto.
    excludes = ['nanopb.proto', 'google/protobuf/descriptor.proto'] + options.exclude + list(f.file_options.exclude)
    includes = [d for d in f.fdesc.dependency if d not in excludes]

    headerdata = ''.join(f.generate_header(includes, headerbasename, options))
    sourcedata = ''.join(f.generate_source(headerbasename, options))

    # Check if there were any lines in .options that did not match a member
    unmatched = [n for n,o in Globals.separate_options if n not in Globals.matched_namemasks]
    if unmatched:
        if options.error_on_unmatched:
            raise Exception("Following patterns in " + f.optfilename + " did not match any fields: "
                            + ', '.join(unmatched))
        elif not options.quiet:
            sys.stderr.write("Following patterns in " + f.optfilename + " did not match any fields: "
                            + ', '.join(unmatched) + "\n")

        if not Globals.verbose_options:
            sys.stderr.write("Use  protoc --nanopb-out=-v:.   to see a list of the field names.\n")

    return {'headername': headername, 'headerdata': headerdata,
            'sourcename': sourcename, 'sourcedata': sourcedata}
2176
def main_cli():
    '''Main function when invoked directly from the command line.

    Parses command line arguments, compiles/loads the given .proto or .pb
    files and writes out the generated .pb.h / .pb.c files.
    Exits the process on bad arguments or protoc failure.
    '''

    options, filenames = optparser.parse_args()

    if options.version:
        print(nanopb_version)
        sys.exit(0)

    if not filenames:
        optparser.print_help()
        sys.exit(1)

    # --quiet overrides --verbose if both were given.
    if options.quiet:
        options.verbose = False

    if options.output_dir and not os.path.exists(options.output_dir):
        optparser.print_help()
        sys.stderr.write("\noutput_dir does not exist: %s\n" % options.output_dir)
        sys.exit(1)

    if options.verbose:
        sys.stderr.write("Nanopb version %s\n" % nanopb_version)
        sys.stderr.write('Google Python protobuf library imported from %s, version %s\n'
                         % (google.protobuf.__file__, google.protobuf.__version__))

    # Load .pb files into memory and compile any .proto files.
    fdescs = {}
    include_path = ['-I%s' % p for p in options.options_path]
    for filename in filenames:
        if filename.endswith(".proto"):
            # Compile the .proto into a temporary FileDescriptorSet first.
            with TemporaryDirectory() as tmpdir:
                tmpname = os.path.join(tmpdir, os.path.basename(filename) + ".pb")
                status = invoke_protoc(["protoc"] + include_path + ['--include_imports', '--include_source_info', '-o' + tmpname, filename])
                if status != 0: sys.exit(status)
                with open(tmpname, 'rb') as fin:
                    data = fin.read()
        else:
            with open(filename, 'rb') as fin:
                data = fin.read()

        # With --include_imports the main file is the last entry in the set.
        fdesc = descriptor.FileDescriptorSet.FromString(data).file[-1]
        fdescs[fdesc.name] = fdesc

    # Process any include files first, in order to have them
    # available as dependencies
    other_files = {}
    for fdesc in fdescs.values():
        other_files[fdesc.name] = parse_file(fdesc.name, fdesc, options)

    # Then generate the headers / sources
    Globals.verbose_options = options.verbose
    for fdesc in fdescs.values():
        results = process_file(fdesc.name, fdesc, options, other_files)

        base_dir = options.output_dir or ''
        to_write = [
            (os.path.join(base_dir, results['headername']), results['headerdata']),
            (os.path.join(base_dir, results['sourcename']), results['sourcedata']),
        ]

        if not options.quiet:
            paths = " and ".join([x[0] for x in to_write])
            sys.stderr.write("Writing to %s\n" % paths)

        for path, data in to_write:
            # Create output directories on demand (e.g. when the .proto
            # lived in a subdirectory).
            dirname = os.path.dirname(path)
            if dirname and not os.path.exists(dirname):
                os.makedirs(dirname)

            with open(path, 'w') as f:
                f.write(data)
2247
def main_plugin():
    '''Main function when invoked as a protoc plugin.

    Reads a serialized CodeGeneratorRequest from stdin, generates the
    .pb.h/.pb.c contents for every requested file, and writes a serialized
    CodeGeneratorResponse to stdout.
    '''

    import io, sys
    if sys.platform == "win32":
        import os, msvcrt
        # Set stdin and stdout to binary mode
        msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)
        msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)

    # closefd=False: flush/close the wrapper without closing the process's
    # real stdin file descriptor.
    with io.open(sys.stdin.fileno(), "rb", closefd=False) as infile:
        data = infile.read()

    request = plugin_pb2.CodeGeneratorRequest.FromString(data)

    try:
        # Versions of Python prior to 2.7.3 do not support unicode
        # input to shlex.split(). Try to convert to str if possible.
        params = str(request.parameter)
    except UnicodeEncodeError:
        params = request.parameter

    import shlex
    args = shlex.split(params)

    if len(args) == 1 and ',' in args[0]:
        # For compatibility with other protoc plugins, support options
        # separated by comma.
        lex = shlex.shlex(params)
        lex.whitespace_split = True
        lex.whitespace = ','
        lex.commenters = ''
        args = list(lex)

    optparser.usage = "Usage: protoc --nanopb_out=[options][,more_options]:outdir file.proto"
    optparser.epilog = "Output will be written to file.pb.h and file.pb.c."

    if '-h' in args or '--help' in args:
        # By default optparser prints help to stdout, which doesn't work for
        # protoc plugins.
        optparser.print_help(sys.stderr)
        sys.exit(1)

    options, dummy = optparser.parse_args(args)

    if options.version:
        sys.stderr.write('%s\n' % (nanopb_version))
        sys.exit(0)

    Globals.verbose_options = options.verbose

    if options.verbose:
        sys.stderr.write("Nanopb version %s\n" % nanopb_version)
        sys.stderr.write('Google Python protobuf library imported from %s, version %s\n'
                         % (google.protobuf.__file__, google.protobuf.__version__))

    response = plugin_pb2.CodeGeneratorResponse()

    # Google's protoc does not currently indicate the full path of proto files.
    # Instead always add the main file path to the search dirs, that works for
    # the common case.
    import os.path
    options.options_path.append(os.path.dirname(request.file_to_generate[0]))

    # Process any include files first, in order to have them
    # available as dependencies
    other_files = {}
    for fdesc in request.proto_file:
        other_files[fdesc.name] = parse_file(fdesc.name, fdesc, options)

    for filename in request.file_to_generate:
        for fdesc in request.proto_file:
            if fdesc.name == filename:
                results = process_file(filename, fdesc, options, other_files)

                f = response.file.add()
                f.name = results['headername']
                f.content = results['headerdata']

                f = response.file.add()
                f.name = results['sourcename']
                f.content = results['sourcedata']

    # Declare support for proto3 optional fields if this protobuf library
    # version knows about the feature flag.
    if hasattr(plugin_pb2.CodeGeneratorResponse, "FEATURE_PROTO3_OPTIONAL"):
        response.supported_features = plugin_pb2.CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL

    # Use a with-block so the buffered writer is flushed deterministically;
    # previously the unflushed writer was left to garbage collection, which
    # is not guaranteed to flush on all Python implementations.
    with io.open(sys.stdout.fileno(), "wb", closefd=False) as outfile:
        outfile.write(response.SerializeToString())
2334
if __name__ == '__main__':
    # protoc invokes plugins through an executable named protoc-gen-<name>,
    # or we may be asked explicitly to act as one via --protoc-plugin.
    running_as_plugin = ('protoc-gen-' in sys.argv[0]
                         or '--protoc-plugin' in sys.argv)
    main_plugin() if running_as_plugin else main_cli()