1#!/usr/bin/env python3 2# kate: replace-tabs on; indent-width 4; 3 4from __future__ import unicode_literals 5 6'''Generate header file for nanopb from a ProtoBuf FileDescriptorSet.''' 7nanopb_version = "nanopb-0.4.7" 8 9import sys 10import re 11import codecs 12import copy 13import itertools 14import tempfile 15import shutil 16import shlex 17import os 18from functools import reduce 19 20# Python-protobuf breaks easily with protoc version differences if 21# using the cpp or upb implementation. Force it to use pure Python 22# implementation. Performance is not very important in the generator. 23if not os.getenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"): 24 os.putenv("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python") 25 os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" 26 27try: 28 # Make sure grpc_tools gets included in binary package if it is available 29 import grpc_tools.protoc 30except: 31 pass 32 33try: 34 import google.protobuf.text_format as text_format 35 import google.protobuf.descriptor_pb2 as descriptor 36 import google.protobuf.compiler.plugin_pb2 as plugin_pb2 37 import google.protobuf.reflection as reflection 38 import google.protobuf.descriptor 39except: 40 sys.stderr.write(''' 41 ********************************************************************** 42 *** Could not import the Google protobuf Python libraries *** 43 *** *** 44 *** Easiest solution is often to install the dependencies via pip: *** 45 *** pip install protobuf grpcio-tools *** 46 ********************************************************************** 47 ''' + '\n') 48 raise 49 50# Depending on how this script is run, we may or may not have PEP366 package name 51# available for relative imports. 52if not __package__: 53 import proto 54 from proto._utils import invoke_protoc 55 from proto import TemporaryDirectory 56else: 57 from . import proto 58 from .proto._utils import invoke_protoc 59 from .proto import TemporaryDirectory 60 61if getattr(sys, 'frozen', False): 62 # Binary package, just import the file 63 from proto import nanopb_pb2 64else: 65 # Try to rebuild nanopb_pb2.py if necessary 66 nanopb_pb2 = proto.load_nanopb_pb2() 67 68try: 69 # Add some dummy imports to keep packaging tools happy. 70 import google # bbfreeze seems to need these 71 import pkg_resources # pyinstaller / protobuf 2.5 seem to need these 72 from proto import nanopb_pb2 # pyinstaller seems to need this 73 import pkg_resources.py2_warn 74except: 75 # Don't care, we will error out later if it is actually important. 76 pass 77 78# --------------------------------------------------------------------------- 79# Generation of single fields 80# --------------------------------------------------------------------------- 81 82import time 83import os.path 84 85# Values are tuple (c type, pb type, encoded size, data_size) 86FieldD = descriptor.FieldDescriptorProto 87datatypes = { 88 FieldD.TYPE_BOOL: ('bool', 'BOOL', 1, 4), 89 FieldD.TYPE_DOUBLE: ('double', 'DOUBLE', 8, 8), 90 FieldD.TYPE_FIXED32: ('uint32_t', 'FIXED32', 4, 4), 91 FieldD.TYPE_FIXED64: ('uint64_t', 'FIXED64', 8, 8), 92 FieldD.TYPE_FLOAT: ('float', 'FLOAT', 4, 4), 93 FieldD.TYPE_INT32: ('int32_t', 'INT32', 10, 4), 94 FieldD.TYPE_INT64: ('int64_t', 'INT64', 10, 8), 95 FieldD.TYPE_SFIXED32: ('int32_t', 'SFIXED32', 4, 4), 96 FieldD.TYPE_SFIXED64: ('int64_t', 'SFIXED64', 8, 8), 97 FieldD.TYPE_SINT32: ('int32_t', 'SINT32', 5, 4), 98 FieldD.TYPE_SINT64: ('int64_t', 'SINT64', 10, 8), 99 FieldD.TYPE_UINT32: ('uint32_t', 'UINT32', 5, 4), 100 FieldD.TYPE_UINT64: ('uint64_t', 'UINT64', 10, 8), 101 102 # Integer size override options 103 (FieldD.TYPE_INT32, nanopb_pb2.IS_8): ('int8_t', 'INT32', 10, 1), 104 (FieldD.TYPE_INT32, nanopb_pb2.IS_16): ('int16_t', 'INT32', 10, 2), 105 (FieldD.TYPE_INT32, nanopb_pb2.IS_32): ('int32_t', 'INT32', 10, 4), 106 (FieldD.TYPE_INT32, nanopb_pb2.IS_64): ('int64_t', 'INT32', 10, 8), 107 (FieldD.TYPE_SINT32, nanopb_pb2.IS_8): ('int8_t', 'SINT32', 2, 1), 108 (FieldD.TYPE_SINT32, nanopb_pb2.IS_16): ('int16_t', 'SINT32', 3, 2), 109 (FieldD.TYPE_SINT32, nanopb_pb2.IS_32): ('int32_t', 'SINT32', 5, 4), 110 (FieldD.TYPE_SINT32, nanopb_pb2.IS_64): ('int64_t', 'SINT32', 10, 8), 111 (FieldD.TYPE_UINT32, nanopb_pb2.IS_8): ('uint8_t', 'UINT32', 2, 1), 112 (FieldD.TYPE_UINT32, nanopb_pb2.IS_16): ('uint16_t','UINT32', 3, 2), 113 (FieldD.TYPE_UINT32, nanopb_pb2.IS_32): ('uint32_t','UINT32', 5, 4), 114 (FieldD.TYPE_UINT32, nanopb_pb2.IS_64): ('uint64_t','UINT32', 10, 8), 115 (FieldD.TYPE_INT64, nanopb_pb2.IS_8): ('int8_t', 'INT64', 10, 1), 116 (FieldD.TYPE_INT64, nanopb_pb2.IS_16): ('int16_t', 'INT64', 10, 2), 117 (FieldD.TYPE_INT64, nanopb_pb2.IS_32): ('int32_t', 'INT64', 10, 4), 118 (FieldD.TYPE_INT64, nanopb_pb2.IS_64): ('int64_t', 'INT64', 10, 8), 119 (FieldD.TYPE_SINT64, nanopb_pb2.IS_8): ('int8_t', 'SINT64', 2, 1), 120 (FieldD.TYPE_SINT64, nanopb_pb2.IS_16): ('int16_t', 'SINT64', 3, 2), 121 (FieldD.TYPE_SINT64, nanopb_pb2.IS_32): ('int32_t', 'SINT64', 5, 4), 122 (FieldD.TYPE_SINT64, nanopb_pb2.IS_64): ('int64_t', 'SINT64', 10, 8), 123 (FieldD.TYPE_UINT64, nanopb_pb2.IS_8): ('uint8_t', 'UINT64', 2, 1), 124 (FieldD.TYPE_UINT64, nanopb_pb2.IS_16): ('uint16_t','UINT64', 3, 2), 125 (FieldD.TYPE_UINT64, nanopb_pb2.IS_32): ('uint32_t','UINT64', 5, 4), 126 (FieldD.TYPE_UINT64, nanopb_pb2.IS_64): ('uint64_t','UINT64', 10, 8), 127} 128 129class NamingStyle: 130 def enum_name(self, name): 131 return "_%s" % (name) 132 133 def struct_name(self, name): 134 return "_%s" % (name) 135 136 def type_name(self, name): 137 return "%s" % (name) 138 139 def define_name(self, name): 140 return "%s" % (name) 141 142 def var_name(self, name): 143 return "%s" % (name) 144 145 def enum_entry(self, name): 146 return "%s" % (name) 147 148 def func_name(self, name): 149 return "%s" % (name) 150 151 def bytes_type(self, struct_name, name): 152 return "%s_%s_t" % (struct_name, name) 153 154class NamingStyleC(NamingStyle): 155 def enum_name(self, name): 156 return self.underscore(name) 157 158 def struct_name(self, name): 159 return self.underscore(name) 160 161 def type_name(self, name): 162 return "%s_t" % self.underscore(name) 163 164 def define_name(self, name): 165 return self.underscore(name).upper() 166 167 def var_name(self, name): 168 return self.underscore(name) 169 170 def enum_entry(self, name): 171 return self.underscore(name).upper() 172 173 def func_name(self, name): 174 return self.underscore(name) 175 176 def bytes_type(self, struct_name, name): 177 return "%s_%s_t" % (self.underscore(struct_name), self.underscore(name)) 178 179 def underscore(self, word): 180 word = str(word) 181 word = re.sub(r"([A-Z]+)([A-Z][a-z])", r'\1_\2', word) 182 word = re.sub(r"([a-z\d])([A-Z])", r'\1_\2', word) 183 word = word.replace("-", "_") 184 return word.lower() 185 186class Globals: 187 '''Ugly global variables, should find a good way to pass these.''' 188 verbose_options = False 189 separate_options = [] 190 matched_namemasks = set() 191 protoc_insertion_points = False 192 naming_style = NamingStyle() 193 194# String types and file encoding for Python2 UTF-8 support 195if sys.version_info.major == 2: 196 import codecs 197 open = codecs.open 198 strtypes = (unicode, str) 199 200 def str(x): 201 try: 202 return strtypes[1](x) 203 except UnicodeEncodeError: 204 return strtypes[0](x) 205else: 206 strtypes = (str, ) 207 208 209class Names: 210 '''Keeps a set of nested names and formats them to C identifier.''' 211 def __init__(self, parts = ()): 212 if isinstance(parts, Names): 213 parts = parts.parts 214 elif isinstance(parts, strtypes): 215 parts = (parts,) 216 self.parts = tuple(parts) 217 218 if self.parts == ('',): 219 self.parts = () 220 221 def __str__(self): 222 return '_'.join(self.parts) 223 224 def __repr__(self): 225 return 'Names(%s)' % ','.join("'%s'" % x for x in self.parts) 226 227 def __add__(self, other): 228 if isinstance(other, strtypes): 229 return Names(self.parts + (other,)) 230 elif isinstance(other, Names): 231 return Names(self.parts + other.parts) 232 elif isinstance(other, tuple): 233 return Names(self.parts + other) 234 else: 235 raise ValueError("Name parts should be of type str") 236 237 def __eq__(self, other): 238 return isinstance(other, Names) and self.parts == other.parts 239 240 def __lt__(self, other): 241 if not isinstance(other, Names): 242 return NotImplemented 243 return str(self) < str(other) 244 245def names_from_type_name(type_name): 246 '''Parse Names() from FieldDescriptorProto type_name''' 247 if type_name[0] != '.': 248 raise NotImplementedError("Lookup of non-absolute type names is not supported") 249 return Names(type_name[1:].split('.')) 250 251def varint_max_size(max_value): 252 '''Returns the maximum number of bytes a varint can take when encoded.''' 253 if max_value < 0: 254 max_value = 2**64 - max_value 255 for i in range(1, 11): 256 if (max_value >> (i * 7)) == 0: 257 return i 258 raise ValueError("Value too large for varint: " + str(max_value)) 259 260assert varint_max_size(-1) == 10 261assert varint_max_size(0) == 1 262assert varint_max_size(127) == 1 263assert varint_max_size(128) == 2 264 265class EncodedSize: 266 '''Class used to represent the encoded size of a field or a message. 267 Consists of a combination of symbolic sizes and integer sizes.''' 268 def __init__(self, value = 0, symbols = [], declarations = [], required_defines = []): 269 if isinstance(value, EncodedSize): 270 self.value = value.value 271 self.symbols = value.symbols 272 self.declarations = value.declarations 273 self.required_defines = value.required_defines 274 elif isinstance(value, strtypes + (Names,)): 275 self.symbols = [str(value)] 276 self.value = 0 277 self.declarations = [] 278 self.required_defines = [str(value)] 279 else: 280 self.value = value 281 self.symbols = symbols 282 self.declarations = declarations 283 self.required_defines = required_defines 284 285 def __add__(self, other): 286 if isinstance(other, int): 287 return EncodedSize(self.value + other, self.symbols, self.declarations, self.required_defines) 288 elif isinstance(other, strtypes + (Names,)): 289 return EncodedSize(self.value, self.symbols + [str(other)], self.declarations, self.required_defines + [str(other)]) 290 elif isinstance(other, EncodedSize): 291 return EncodedSize(self.value + other.value, self.symbols + other.symbols, 292 self.declarations + other.declarations, self.required_defines + other.required_defines) 293 else: 294 raise ValueError("Cannot add size: " + repr(other)) 295 296 def __mul__(self, other): 297 if isinstance(other, int): 298 return EncodedSize(self.value * other, [str(other) + '*' + s for s in self.symbols], 299 self.declarations, self.required_defines) 300 else: 301 raise ValueError("Cannot multiply size: " + repr(other)) 302 303 def __str__(self): 304 if not self.symbols: 305 return str(self.value) 306 else: 307 return '(' + str(self.value) + ' + ' + ' + '.join(self.symbols) + ')' 308 309 def __repr__(self): 310 return 'EncodedSize(%s, %s, %s, %s)' % (self.value, self.symbols, self.declarations, self.required_defines) 311 312 def get_declarations(self): 313 '''Get any declarations that must appear alongside this encoded size definition, 314 such as helper union {} types.''' 315 return '\n'.join(self.declarations) 316 317 def get_cpp_guard(self, local_defines): 318 '''Get an #if preprocessor statement listing all defines that are required for this definition.''' 319 needed = [x for x in self.required_defines if x not in local_defines] 320 if needed: 321 return '#if ' + ' && '.join(['defined(%s)' % x for x in needed]) + "\n" 322 else: 323 return '' 324 325 def upperlimit(self): 326 if not self.symbols: 327 return self.value 328 else: 329 return 2**32 - 1 330 331class ProtoElement(object): 332 # Constants regarding path of proto elements in file descriptor. 333 # They are used to connect proto elements with source code information (comments) 334 # These values come from: 335 # https://github.com/google/protobuf/blob/master/src/google/protobuf/descriptor.proto 336 FIELD = 2 337 MESSAGE = 4 338 ENUM = 5 339 NESTED_TYPE = 3 340 NESTED_ENUM = 4 341 342 def __init__(self, path, comments = None): 343 ''' 344 path is a tuple containing integers (type, index, ...) 345 comments is a dictionary mapping between element path & SourceCodeInfo.Location 346 (contains information about source comments). 347 ''' 348 assert(isinstance(path, tuple)) 349 self.element_path = path 350 self.comments = comments or {} 351 352 def get_member_comments(self, index): 353 '''Get comments for a member of enum or message.''' 354 return self.get_comments((ProtoElement.FIELD, index), leading_indent = True) 355 356 def format_comment(self, comment): 357 '''Put comment inside /* */ and sanitize comment contents''' 358 comment = comment.strip() 359 comment = comment.replace('/*', '/ *') 360 comment = comment.replace('*/', '* /') 361 return "/* %s */" % comment 362 363 def get_comments(self, member_path = (), leading_indent = False): 364 '''Get leading & trailing comments for a protobuf element. 365 366 member_path is the proto path of an element or member (ex. [5 0] or [4 1 2 0]) 367 leading_indent is a flag that indicates if leading comments should be indented 368 ''' 369 370 # Obtain SourceCodeInfo.Location object containing comment 371 # information (based on the member path) 372 path = self.element_path + member_path 373 comment = self.comments.get(path) 374 375 leading_comment = "" 376 trailing_comment = "" 377 378 if not comment: 379 return leading_comment, trailing_comment 380 381 if comment.leading_comments: 382 leading_comment = " " if leading_indent else "" 383 leading_comment += self.format_comment(comment.leading_comments) 384 385 if comment.trailing_comments: 386 trailing_comment = self.format_comment(comment.trailing_comments) 387 388 return leading_comment, trailing_comment 389 390 391class Enum(ProtoElement): 392 def __init__(self, names, desc, enum_options, element_path, comments): 393 ''' 394 desc is EnumDescriptorProto 395 index is the index of this enum element inside the file 396 comments is a dictionary mapping between element path & SourceCodeInfo.Location 397 (contains information about source comments) 398 ''' 399 super(Enum, self).__init__(element_path, comments) 400 401 self.options = enum_options 402 self.names = names 403 404 # by definition, `names` include this enum's name 405 base_name = Names(names.parts[:-1]) 406 407 if enum_options.long_names: 408 self.values = [(names + x.name, x.number) for x in desc.value] 409 else: 410 self.values = [(base_name + x.name, x.number) for x in desc.value] 411 412 self.value_longnames = [self.names + x.name for x in desc.value] 413 self.packed = enum_options.packed_enum 414 415 def has_negative(self): 416 for n, v in self.values: 417 if v < 0: 418 return True 419 return False 420 421 def encoded_size(self): 422 return max([varint_max_size(v) for n,v in self.values]) 423 424 def __repr__(self): 425 return 'Enum(%s)' % self.names 426 427 def __str__(self): 428 leading_comment, trailing_comment = self.get_comments() 429 430 result = '' 431 if leading_comment: 432 result = '%s\n' % leading_comment 433 434 result += 'typedef enum %s {' % Globals.naming_style.enum_name(self.names) 435 if trailing_comment: 436 result += " " + trailing_comment 437 438 result += "\n" 439 440 enum_length = len(self.values) 441 enum_values = [] 442 for index, (name, value) in enumerate(self.values): 443 leading_comment, trailing_comment = self.get_member_comments(index) 444 445 if leading_comment: 446 enum_values.append(leading_comment) 447 448 comma = "," 449 if index == enum_length - 1: 450 # last enum member should not end with a comma 451 comma = "" 452 453 enum_value = " %s = %d%s" % (Globals.naming_style.enum_entry(name), value, comma) 454 if trailing_comment: 455 enum_value += " " + trailing_comment 456 457 enum_values.append(enum_value) 458 459 result += '\n'.join(enum_values) 460 result += '\n}' 461 462 if self.packed: 463 result += ' pb_packed' 464 465 result += ' %s;' % Globals.naming_style.type_name(self.names) 466 return result 467 468 def auxiliary_defines(self): 469 # sort the enum by value 470 sorted_values = sorted(self.values, key = lambda x: (x[1], x[0])) 471 result = '#define %s %s\n' % ( 472 Globals.naming_style.define_name('_%s_MIN' % self.names), 473 Globals.naming_style.enum_entry(sorted_values[0][0])) 474 result += '#define %s %s\n' % ( 475 Globals.naming_style.define_name('_%s_MAX' % self.names), 476 Globals.naming_style.enum_entry(sorted_values[-1][0])) 477 result += '#define %s ((%s)(%s+1))\n' % ( 478 Globals.naming_style.define_name('_%s_ARRAYSIZE' % self.names), 479 Globals.naming_style.type_name(self.names), 480 Globals.naming_style.enum_entry(sorted_values[-1][0])) 481 482 if not self.options.long_names: 483 # Define the long names always so that enum value references 484 # from other files work properly. 485 for i, x in enumerate(self.values): 486 result += '#define %s %s\n' % (self.value_longnames[i], x[0]) 487 488 if self.options.enum_to_string: 489 result += 'const char *%s(%s v);\n' % ( 490 Globals.naming_style.func_name('%s_name' % self.names), 491 Globals.naming_style.type_name(self.names)) 492 493 return result 494 495 def enum_to_string_definition(self): 496 if not self.options.enum_to_string: 497 return "" 498 499 result = 'const char *%s(%s v) {\n' % ( 500 Globals.naming_style.func_name('%s_name' % self.names), 501 Globals.naming_style.type_name(self.names)) 502 503 result += ' switch (v) {\n' 504 505 for ((enumname, _), strname) in zip(self.values, self.value_longnames): 506 # Strip off the leading type name from the string value. 507 strval = str(strname)[len(str(self.names)) + 1:] 508 result += ' case %s: return "%s";\n' % ( 509 Globals.naming_style.enum_entry(enumname), 510 Globals.naming_style.enum_entry(strval)) 511 512 result += ' }\n' 513 result += ' return "unknown";\n' 514 result += '}\n' 515 516 return result 517 518class FieldMaxSize: 519 def __init__(self, worst = 0, checks = [], field_name = 'undefined'): 520 if isinstance(worst, list): 521 self.worst = max(i for i in worst if i is not None) 522 else: 523 self.worst = worst 524 525 self.worst_field = field_name 526 self.checks = list(checks) 527 528 def extend(self, extend, field_name = None): 529 self.worst = max(self.worst, extend.worst) 530 531 if self.worst == extend.worst: 532 self.worst_field = extend.worst_field 533 534 self.checks.extend(extend.checks) 535 536class Field(ProtoElement): 537 macro_x_param = 'X' 538 macro_a_param = 'a' 539 540 def __init__(self, struct_name, desc, field_options, element_path = (), comments = None): 541 '''desc is FieldDescriptorProto''' 542 ProtoElement.__init__(self, element_path, comments) 543 self.tag = desc.number 544 self.struct_name = struct_name 545 self.union_name = None 546 self.name = desc.name 547 self.default = None 548 self.max_size = None 549 self.max_count = None 550 self.array_decl = "" 551 self.enc_size = None 552 self.data_item_size = None 553 self.ctype = None 554 self.fixed_count = False 555 self.callback_datatype = field_options.callback_datatype 556 self.math_include_required = False 557 self.sort_by_tag = field_options.sort_by_tag 558 559 if field_options.type == nanopb_pb2.FT_INLINE: 560 # Before nanopb-0.3.8, fixed length bytes arrays were specified 561 # by setting type to FT_INLINE. But to handle pointer typed fields, 562 # it makes sense to have it as a separate option. 563 field_options.type = nanopb_pb2.FT_STATIC 564 field_options.fixed_length = True 565 566 # Parse field options 567 if field_options.HasField("max_size"): 568 self.max_size = field_options.max_size 569 570 self.default_has = field_options.default_has 571 572 if desc.type == FieldD.TYPE_STRING and field_options.HasField("max_length"): 573 # max_length overrides max_size for strings 574 self.max_size = field_options.max_length + 1 575 576 if field_options.HasField("max_count"): 577 self.max_count = field_options.max_count 578 579 if desc.HasField('default_value'): 580 self.default = desc.default_value 581 582 # Check field rules, i.e. required/optional/repeated. 583 can_be_static = True 584 if desc.label == FieldD.LABEL_REPEATED: 585 self.rules = 'REPEATED' 586 if self.max_count is None: 587 can_be_static = False 588 else: 589 self.array_decl = '[%d]' % self.max_count 590 if field_options.fixed_count: 591 self.rules = 'FIXARRAY' 592 593 elif field_options.proto3: 594 if desc.type == FieldD.TYPE_MESSAGE and not field_options.proto3_singular_msgs: 595 # In most other protobuf libraries proto3 submessages have 596 # "null" status. For nanopb, that is implemented as has_ field. 597 self.rules = 'OPTIONAL' 598 elif hasattr(desc, "proto3_optional") and desc.proto3_optional: 599 # Protobuf 3.12 introduced optional fields for proto3 syntax 600 self.rules = 'OPTIONAL' 601 else: 602 # Proto3 singular fields (without has_field) 603 self.rules = 'SINGULAR' 604 elif desc.label == FieldD.LABEL_REQUIRED: 605 self.rules = 'REQUIRED' 606 elif desc.label == FieldD.LABEL_OPTIONAL: 607 self.rules = 'OPTIONAL' 608 else: 609 raise NotImplementedError(desc.label) 610 611 # Check if the field can be implemented with static allocation 612 # i.e. whether the data size is known. 613 if desc.type == FieldD.TYPE_STRING and self.max_size is None: 614 can_be_static = False 615 616 if desc.type == FieldD.TYPE_BYTES and self.max_size is None: 617 can_be_static = False 618 619 # Decide how the field data will be allocated 620 if field_options.type == nanopb_pb2.FT_DEFAULT: 621 if can_be_static: 622 field_options.type = nanopb_pb2.FT_STATIC 623 else: 624 field_options.type = field_options.fallback_type 625 626 if field_options.type == nanopb_pb2.FT_STATIC and not can_be_static: 627 raise Exception("Field '%s' is defined as static, but max_size or " 628 "max_count is not given." % self.name) 629 630 if field_options.fixed_count and self.max_count is None: 631 raise Exception("Field '%s' is defined as fixed count, " 632 "but max_count is not given." % self.name) 633 634 if field_options.type == nanopb_pb2.FT_STATIC: 635 self.allocation = 'STATIC' 636 elif field_options.type == nanopb_pb2.FT_POINTER: 637 self.allocation = 'POINTER' 638 elif field_options.type == nanopb_pb2.FT_CALLBACK: 639 self.allocation = 'CALLBACK' 640 else: 641 raise NotImplementedError(field_options.type) 642 643 if field_options.HasField("type_override"): 644 desc.type = field_options.type_override 645 646 # Decide the C data type to use in the struct. 647 if desc.type in datatypes: 648 self.ctype, self.pbtype, self.enc_size, self.data_item_size = datatypes[desc.type] 649 650 # Override the field size if user wants to use smaller integers 651 if (desc.type, field_options.int_size) in datatypes: 652 self.ctype, self.pbtype, self.enc_size, self.data_item_size = datatypes[(desc.type, field_options.int_size)] 653 elif desc.type == FieldD.TYPE_ENUM: 654 self.pbtype = 'ENUM' 655 self.data_item_size = 4 656 self.ctype = names_from_type_name(desc.type_name) 657 if self.default is not None: 658 self.default = self.ctype + self.default 659 self.enc_size = None # Needs to be filled in when enum values are known 660 elif desc.type == FieldD.TYPE_STRING: 661 self.pbtype = 'STRING' 662 self.ctype = 'char' 663 if self.allocation == 'STATIC': 664 self.ctype = 'char' 665 self.array_decl += '[%d]' % self.max_size 666 # -1 because of null terminator. Both pb_encode and pb_decode 667 # check the presence of it. 668 self.enc_size = varint_max_size(self.max_size) + self.max_size - 1 669 elif desc.type == FieldD.TYPE_BYTES: 670 if field_options.fixed_length: 671 self.pbtype = 'FIXED_LENGTH_BYTES' 672 673 if self.max_size is None: 674 raise Exception("Field '%s' is defined as fixed length, " 675 "but max_size is not given." % self.name) 676 677 self.enc_size = varint_max_size(self.max_size) + self.max_size 678 self.ctype = 'pb_byte_t' 679 self.array_decl += '[%d]' % self.max_size 680 else: 681 self.pbtype = 'BYTES' 682 self.ctype = 'pb_bytes_array_t' 683 if self.allocation == 'STATIC': 684 self.ctype = Globals.naming_style.bytes_type(self.struct_name, self.name) 685 self.enc_size = varint_max_size(self.max_size) + self.max_size 686 elif desc.type == FieldD.TYPE_MESSAGE: 687 self.pbtype = 'MESSAGE' 688 self.ctype = self.submsgname = names_from_type_name(desc.type_name) 689 self.enc_size = None # Needs to be filled in after the message type is available 690 if field_options.submsg_callback and self.allocation == 'STATIC': 691 self.pbtype = 'MSG_W_CB' 692 else: 693 raise NotImplementedError(desc.type) 694 695 if self.default and self.pbtype in ['FLOAT', 'DOUBLE']: 696 if 'inf' in self.default or 'nan' in self.default: 697 self.math_include_required = True 698 699 def __lt__(self, other): 700 return self.tag < other.tag 701 702 def __repr__(self): 703 return 'Field(%s)' % self.name 704 705 def __str__(self): 706 result = '' 707 708 var_name = Globals.naming_style.var_name(self.name) 709 type_name = Globals.naming_style.type_name(self.ctype) if isinstance(self.ctype, Names) else self.ctype 710 711 if self.allocation == 'POINTER': 712 if self.rules == 'REPEATED': 713 if self.pbtype == 'MSG_W_CB': 714 result += ' pb_callback_t cb_' + var_name + ';\n' 715 result += ' pb_size_t ' + var_name + '_count;\n' 716 717 if self.rules == 'FIXARRAY' and self.pbtype in ['STRING', 'BYTES']: 718 # Pointer to fixed size array of pointers 719 result += ' %s* (*%s)%s;' % (type_name, var_name, self.array_decl) 720 elif self.pbtype == 'FIXED_LENGTH_BYTES' or self.rules == 'FIXARRAY': 721 # Pointer to fixed size array of items 722 result += ' %s (*%s)%s;' % (type_name, var_name, self.array_decl) 723 elif self.rules == 'REPEATED' and self.pbtype in ['STRING', 'BYTES']: 724 # String/bytes arrays need to be defined as pointers to pointers 725 result += ' %s **%s;' % (type_name, var_name) 726 elif self.pbtype in ['MESSAGE', 'MSG_W_CB']: 727 # Use struct definition, so recursive submessages are possible 728 result += ' struct %s *%s;' % (Globals.naming_style.struct_name(self.ctype), var_name) 729 else: 730 # Normal case, just a pointer to single item 731 result += ' %s *%s;' % (type_name, var_name) 732 elif self.allocation == 'CALLBACK': 733 result += ' %s %s;' % (self.callback_datatype, var_name) 734 else: 735 if self.pbtype == 'MSG_W_CB' and self.rules in ['OPTIONAL', 'REPEATED']: 736 result += ' pb_callback_t cb_' + var_name + ';\n' 737 738 if self.rules == 'OPTIONAL': 739 result += ' bool has_' + var_name + ';\n' 740 elif self.rules == 'REPEATED': 741 result += ' pb_size_t ' + var_name + '_count;\n' 742 743 result += ' %s %s%s;' % (type_name, var_name, self.array_decl) 744 745 leading_comment, trailing_comment = self.get_comments(leading_indent = True) 746 if leading_comment: result = leading_comment + "\n" + result 747 if trailing_comment: result = result + " " + trailing_comment 748 749 return result 750 751 def types(self): 752 '''Return definitions for any special types this field might need.''' 753 if self.pbtype == 'BYTES' and self.allocation == 'STATIC': 754 result = 'typedef PB_BYTES_ARRAY_T(%d) %s;\n' % (self.max_size, Globals.naming_style.var_name(self.ctype)) 755 else: 756 result = '' 757 return result 758 759 def get_dependencies(self): 760 '''Get list of type names used by this field.''' 761 if self.allocation == 'STATIC': 762 return [str(self.ctype)] 763 elif self.allocation == 'POINTER' and self.rules == 'FIXARRAY': 764 return [str(self.ctype)] 765 else: 766 return [] 767 768 def get_initializer(self, null_init, inner_init_only = False): 769 '''Return literal expression for this field's default value. 770 null_init: If True, initialize to a 0 value instead of default from .proto 771 inner_init_only: If True, exclude initialization for any count/has fields 772 ''' 773 774 inner_init = None 775 if self.pbtype in ['MESSAGE', 'MSG_W_CB']: 776 if null_init: 777 inner_init = Globals.naming_style.define_name('%s_init_zero' % self.ctype) 778 else: 779 inner_init = Globals.naming_style.define_name('%s_init_default' % self.ctype) 780 elif self.default is None or null_init: 781 if self.pbtype == 'STRING': 782 inner_init = '""' 783 elif self.pbtype == 'BYTES': 784 inner_init = '{0, {0}}' 785 elif self.pbtype == 'FIXED_LENGTH_BYTES': 786 inner_init = '{0}' 787 elif self.pbtype in ('ENUM', 'UENUM'): 788 inner_init = '_%s_MIN' % Globals.naming_style.define_name(self.ctype) 789 else: 790 inner_init = '0' 791 else: 792 if self.pbtype == 'STRING': 793 data = codecs.escape_encode(self.default.encode('utf-8'))[0] 794 inner_init = '"' + data.decode('ascii') + '"' 795 elif self.pbtype == 'BYTES': 796 data = codecs.escape_decode(self.default)[0] 797 data = ["0x%02x" % c for c in bytearray(data)] 798 if len(data) == 0: 799 inner_init = '{0, {0}}' 800 else: 801 inner_init = '{%d, {%s}}' % (len(data), ','.join(data)) 802 elif self.pbtype == 'FIXED_LENGTH_BYTES': 803 data = codecs.escape_decode(self.default)[0] 804 data = ["0x%02x" % c for c in bytearray(data)] 805 if len(data) == 0: 806 inner_init = '{0}' 807 else: 808 inner_init = '{%s}' % ','.join(data) 809 elif self.pbtype in ['FIXED32', 'UINT32']: 810 inner_init = str(self.default) + 'u' 811 elif self.pbtype in ['FIXED64', 'UINT64']: 812 inner_init = str(self.default) + 'ull' 813 elif self.pbtype in ['SFIXED64', 'INT64']: 814 inner_init = str(self.default) + 'll' 815 elif self.pbtype in ['FLOAT', 'DOUBLE']: 816 inner_init = str(self.default) 817 if 'inf' in inner_init: 818 inner_init = inner_init.replace('inf', 'INFINITY') 819 elif 'nan' in inner_init: 820 inner_init = inner_init.replace('nan', 'NAN') 821 elif (not '.' in inner_init) and self.pbtype == 'FLOAT': 822 inner_init += '.0f' 823 elif self.pbtype == 'FLOAT': 824 inner_init += 'f' 825 else: 826 inner_init = str(self.default) 827 828 if inner_init_only: 829 return inner_init 830 831 outer_init = None 832 if self.allocation == 'STATIC': 833 if self.rules == 'REPEATED': 834 outer_init = '0, {' + ', '.join([inner_init] * self.max_count) + '}' 835 elif self.rules == 'FIXARRAY': 836 outer_init = '{' + ', '.join([inner_init] * self.max_count) + '}' 837 elif self.rules == 'OPTIONAL': 838 if null_init or not self.default_has: 839 outer_init = 'false, ' + inner_init 840 else: 841 outer_init = 'true, ' + inner_init 842 else: 843 outer_init = inner_init 844 elif self.allocation == 'POINTER': 845 if self.rules == 'REPEATED': 846 outer_init = '0, NULL' 847 else: 848 outer_init = 'NULL' 849 elif self.allocation == 'CALLBACK': 850 if self.pbtype == 'EXTENSION': 851 outer_init = 'NULL' 852 else: 853 outer_init = '{{NULL}, NULL}' 854 855 if self.pbtype == 'MSG_W_CB' and self.rules in ['REPEATED', 'OPTIONAL']: 856 outer_init = '{{NULL}, NULL}, ' + outer_init 857 858 return outer_init 859 860 def tags(self): 861 '''Return the #define for the tag number of this field.''' 862 identifier = Globals.naming_style.define_name('%s_%s_tag' % (self.struct_name, self.name)) 863 return '#define %-40s %d\n' % (identifier, self.tag) 864 865 def fieldlist(self): 866 '''Return the FIELDLIST macro entry for this field. 867 Format is: X(a, ATYPE, HTYPE, LTYPE, field_name, tag) 868 ''' 869 name = Globals.naming_style.var_name(self.name) 870 871 if self.rules == "ONEOF": 872 # For oneofs, make a tuple of the union name, union member name, 873 # and the name inside the parent struct. 874 if not self.anonymous: 875 name = '(%s,%s,%s)' % ( 876 Globals.naming_style.var_name(self.union_name), 877 Globals.naming_style.var_name(self.name), 878 Globals.naming_style.var_name(self.union_name) + '.' + 879 Globals.naming_style.var_name(self.name)) 880 else: 881 name = '(%s,%s,%s)' % ( 882 Globals.naming_style.var_name(self.union_name), 883 Globals.naming_style.var_name(self.name), 884 Globals.naming_style.var_name(self.name)) 885 886 return '%s(%s, %-9s %-9s %-9s %-16s %3d)' % (self.macro_x_param, 887 self.macro_a_param, 888 self.allocation + ',', 889 self.rules + ',', 890 self.pbtype + ',', 891 name + ',', 892 self.tag) 893 894 def data_size(self, dependencies): 895 '''Return estimated size of this field in the C struct. 896 This is used to try to automatically pick right descriptor size. 897 If the estimate is wrong, it will result in compile time error and 898 user having to specify descriptor_width option. 899 ''' 900 if self.allocation == 'POINTER' or self.pbtype == 'EXTENSION': 901 size = 8 902 alignment = 8 903 elif self.allocation == 'CALLBACK': 904 size = 16 905 alignment = 8 906 elif self.pbtype in ['MESSAGE', 'MSG_W_CB']: 907 alignment = 8 908 if str(self.submsgname) in dependencies: 909 other_dependencies = dict(x for x in dependencies.items() if x[0] != str(self.struct_name)) 910 size = dependencies[str(self.submsgname)].data_size(other_dependencies) 911 else: 912 size = 256 # Message is in other file, this is reasonable guess for most cases 913 sys.stderr.write('Could not determine size for submessage %s, using default %d\n' % (self.submsgname, size)) 914 915 if self.pbtype == 'MSG_W_CB': 916 size += 16 917 elif self.pbtype in ['STRING', 'FIXED_LENGTH_BYTES']: 918 size = self.max_size 919 alignment = 4 920 elif self.pbtype == 'BYTES': 921 size = self.max_size + 4 922 alignment = 4 923 elif self.data_item_size is not None: 924 size = self.data_item_size 925 alignment = 4 926 if self.data_item_size >= 8: 927 alignment = 8 928 else: 929 raise Exception("Unhandled field type: %s" % self.pbtype) 930 931 if self.rules in ['REPEATED', 'FIXARRAY'] and self.allocation == 'STATIC': 932 size *= self.max_count 933 934 if self.rules not in ('REQUIRED', 'SINGULAR'): 935 size += 4 936 937 if size % alignment != 0: 938 # Estimate how much alignment requirements will increase the size. 939 size += alignment - (size % alignment) 940 941 return size 942 943 def encoded_size(self, dependencies): 944 '''Return the maximum size that this field can take when encoded, 945 including the field tag. If the size cannot be determined, returns 946 None.''' 947 948 if self.allocation != 'STATIC': 949 return None 950 951 if self.pbtype in ['MESSAGE', 'MSG_W_CB']: 952 encsize = None 953 if str(self.submsgname) in dependencies: 954 submsg = dependencies[str(self.submsgname)] 955 other_dependencies = dict(x for x in dependencies.items() if x[0] != str(self.struct_name)) 956 encsize = submsg.encoded_size(other_dependencies) 957 958 my_msg = dependencies.get(str(self.struct_name)) 959 external = (not my_msg or submsg.protofile != my_msg.protofile) 960 961 if encsize and encsize.symbols and external: 962 # Couldn't fully resolve the size of a dependency from 963 # another file. Instead of including the symbols directly, 964 # just use the #define SubMessage_size from the header. 965 encsize = None 966 967 if encsize is not None: 968 # Include submessage length prefix 969 encsize += varint_max_size(encsize.upperlimit()) 970 elif not external: 971 # The dependency is from the same file and size cannot be 972 # determined for it, thus we know it will not be possible 973 # in runtime either. 974 return None 975 976 if encsize is None: 977 # Submessage or its size cannot be found. 978 # This can occur if submessage is defined in different 979 # file, and it or its .options could not be found. 980 # Instead of direct numeric value, reference the size that 981 # has been #defined in the other file. 982 encsize = EncodedSize(self.submsgname + 'size') 983 984 # We will have to make a conservative assumption on the length 985 # prefix size, though. 986 encsize += 5 987 988 elif self.pbtype in ['ENUM', 'UENUM']: 989 if str(self.ctype) in dependencies: 990 enumtype = dependencies[str(self.ctype)] 991 encsize = enumtype.encoded_size() 992 else: 993 # Conservative assumption 994 encsize = 10 995 996 elif self.enc_size is None: 997 raise RuntimeError("Could not determine encoded size for %s.%s" 998 % (self.struct_name, self.name)) 999 else: 1000 encsize = EncodedSize(self.enc_size) 1001 1002 encsize += varint_max_size(self.tag << 3) # Tag + wire type 1003 1004 if self.rules in ['REPEATED', 'FIXARRAY']: 1005 # Decoders must be always able to handle unpacked arrays. 1006 # Therefore we have to reserve space for it, even though 1007 # we emit packed arrays ourselves. For length of 1, packed 1008 # arrays are larger however so we need to add allowance 1009 # for the length byte. 1010 encsize *= self.max_count 1011 1012 if self.max_count == 1: 1013 encsize += 1 1014 1015 return encsize 1016 1017 def has_callbacks(self): 1018 return self.allocation == 'CALLBACK' 1019 1020 def requires_custom_field_callback(self): 1021 return self.allocation == 'CALLBACK' and self.callback_datatype != 'pb_callback_t' 1022 1023class ExtensionRange(Field): 1024 def __init__(self, struct_name, range_start, field_options): 1025 '''Implements a special pb_extension_t* field in an extensible message 1026 structure. The range_start signifies the index at which the extensions 1027 start. Not necessarily all tags above this are extensions, it is merely 1028 a speed optimization. 1029 ''' 1030 self.tag = range_start 1031 self.struct_name = struct_name 1032 self.name = 'extensions' 1033 self.pbtype = 'EXTENSION' 1034 self.rules = 'OPTIONAL' 1035 self.allocation = 'CALLBACK' 1036 self.ctype = 'pb_extension_t' 1037 self.array_decl = '' 1038 self.default = None 1039 self.max_size = 0 1040 self.max_count = 0 1041 self.data_item_size = 0 1042 self.fixed_count = False 1043 self.callback_datatype = 'pb_extension_t*' 1044 1045 def requires_custom_field_callback(self): 1046 return False 1047 1048 def __str__(self): 1049 return ' pb_extension_t *extensions;' 1050 1051 def types(self): 1052 return '' 1053 1054 def tags(self): 1055 return '' 1056 1057 def encoded_size(self, dependencies): 1058 # We exclude extensions from the count, because they cannot be known 1059 # until runtime. Other option would be to return None here, but this 1060 # way the value remains useful if extensions are not used. 1061 return EncodedSize(0) 1062 1063class ExtensionField(Field): 1064 def __init__(self, fullname, desc, field_options): 1065 self.fullname = fullname 1066 self.extendee_name = names_from_type_name(desc.extendee) 1067 Field.__init__(self, self.fullname + "extmsg", desc, field_options) 1068 1069 if self.rules != 'OPTIONAL': 1070 self.skip = True 1071 else: 1072 self.skip = False 1073 self.rules = 'REQUIRED' # We don't really want the has_field for extensions 1074 # currently no support for comments for extension fields => provide (), {} 1075 self.msg = Message(self.fullname + "extmsg", None, field_options, (), {}) 1076 self.msg.fields.append(self) 1077 1078 def tags(self): 1079 '''Return the #define for the tag number of this field.''' 1080 identifier = Globals.naming_style.define_name('%s_tag' % (self.fullname)) 1081 return '#define %-40s %d\n' % (identifier, self.tag) 1082 1083 def extension_decl(self): 1084 '''Declaration of the extension type in the .pb.h file''' 1085 if self.skip: 1086 msg = '/* Extension field %s was skipped because only "optional"\n' % self.fullname 1087 msg +=' type of extension fields is currently supported. */\n' 1088 return msg 1089 1090 return ('extern const pb_extension_type_t %s; /* field type: %s */\n' % 1091 (Globals.naming_style.var_name(self.fullname), str(self).strip())) 1092 1093 def extension_def(self, dependencies): 1094 '''Definition of the extension type in the .pb.c file''' 1095 1096 if self.skip: 1097 return '' 1098 1099 result = "/* Definition for extension field %s */\n" % self.fullname 1100 result += str(self.msg) 1101 result += self.msg.fields_declaration(dependencies) 1102 result += 'pb_byte_t %s_default[] = {0x00};\n' % self.msg.name 1103 result += self.msg.fields_definition(dependencies) 1104 result += 'const pb_extension_type_t %s = {\n' % Globals.naming_style.var_name(self.fullname) 1105 result += ' NULL,\n' 1106 result += ' NULL,\n' 1107 result += ' &%s_msg\n' % Globals.naming_style.type_name(self.msg.name) 1108 result += '};\n' 1109 return result 1110 1111 1112# --------------------------------------------------------------------------- 1113# Generation of oneofs (unions) 1114# --------------------------------------------------------------------------- 1115 1116class OneOf(Field): 1117 def __init__(self, struct_name, oneof_desc, oneof_options): 1118 self.struct_name = struct_name 1119 self.name = oneof_desc.name 1120 self.ctype = 'union' 1121 self.pbtype = 'oneof' 1122 self.fields = [] 1123 self.allocation = 'ONEOF' 1124 self.default = None 1125 self.rules = 'ONEOF' 1126 self.anonymous = oneof_options.anonymous_oneof 1127 self.sort_by_tag = oneof_options.sort_by_tag 1128 self.has_msg_cb = False 1129 1130 def add_field(self, field): 1131 field.union_name = self.name 1132 field.rules = 'ONEOF' 1133 field.anonymous = self.anonymous 1134 self.fields.append(field) 1135 1136 if self.sort_by_tag: 1137 self.fields.sort() 1138 1139 if field.pbtype == 'MSG_W_CB': 1140 self.has_msg_cb = True 1141 1142 # Sort by the lowest tag number inside union 1143 self.tag = min([f.tag for f in self.fields]) 1144 1145 def __str__(self): 1146 result = '' 1147 if self.fields: 1148 if self.has_msg_cb: 1149 result += ' pb_callback_t cb_' + Globals.naming_style.var_name(self.name) + ';\n' 1150 1151 result += ' pb_size_t which_' + Globals.naming_style.var_name(self.name) + ";\n" 1152 result += ' union {\n' 1153 for f in self.fields: 1154 result += ' ' + str(f).replace('\n', '\n ') + '\n' 1155 if self.anonymous: 1156 result += ' };' 1157 else: 1158 result += ' } ' + Globals.naming_style.var_name(self.name) + ';' 1159 return result 1160 1161 def types(self): 1162 return ''.join([f.types() for f in self.fields]) 1163 1164 def get_dependencies(self): 1165 deps = [] 1166 for f in self.fields: 1167 deps += f.get_dependencies() 1168 return deps 1169 1170 def get_initializer(self, null_init): 1171 if self.has_msg_cb: 1172 return '{{NULL}, NULL}, 0, {' + self.fields[0].get_initializer(null_init) + '}' 1173 else: 1174 return '0, {' + self.fields[0].get_initializer(null_init) + '}' 1175 1176 def tags(self): 1177 return ''.join([f.tags() for f in self.fields]) 1178 1179 def data_size(self, dependencies): 1180 return max(f.data_size(dependencies) for f in self.fields) 1181 1182 def encoded_size(self, dependencies): 1183 '''Returns the size of the largest oneof field.''' 1184 largest = 0 1185 dynamic_sizes = {} 1186 for f in self.fields: 1187 size = EncodedSize(f.encoded_size(dependencies)) 1188 if size is None or size.value is None: 1189 return None 1190 elif size.symbols: 1191 dynamic_sizes[f.tag] = size 1192 elif size.value > largest: 1193 largest = size.value 1194 1195 if not dynamic_sizes: 1196 # Simple case, all sizes were known at generator time 1197 return EncodedSize(largest) 1198 1199 if largest > 0: 1200 # Some sizes were known, some were not 1201 dynamic_sizes[0] = EncodedSize(largest) 1202 1203 # Couldn't find size for submessage at generation time, 1204 # have to rely on macro resolution at compile time. 1205 if len(dynamic_sizes) == 1: 1206 # Only one symbol was needed 1207 return list(dynamic_sizes.values())[0] 1208 else: 1209 # Use sizeof(union{}) construct to find the maximum size of 1210 # submessages. 1211 union_name = "%s_%s_size_union" % (self.struct_name, self.name) 1212 union_def = 'union %s {%s};\n' % (union_name, ' '.join('char f%d[%s];' % (k, s) for k,s in dynamic_sizes.items())) 1213 required_defs = list(itertools.chain.from_iterable(s.required_defines for k,s in dynamic_sizes.items())) 1214 return EncodedSize(0, ['sizeof(union %s)' % union_name], [union_def], required_defs) 1215 1216 def has_callbacks(self): 1217 return bool([f for f in self.fields if f.has_callbacks()]) 1218 1219 def requires_custom_field_callback(self): 1220 return bool([f for f in self.fields if f.requires_custom_field_callback()]) 1221 1222# --------------------------------------------------------------------------- 1223# Generation of messages (structures) 1224# --------------------------------------------------------------------------- 1225 1226 1227class Message(ProtoElement): 1228 def __init__(self, names, desc, message_options, element_path, comments): 1229 super(Message, self).__init__(element_path, comments) 1230 self.name = names 1231 self.fields = [] 1232 self.oneofs = {} 1233 self.desc = desc 1234 self.math_include_required = False 1235 self.packed = message_options.packed_struct 1236 self.descriptorsize = message_options.descriptorsize 1237 1238 if message_options.msgid: 1239 self.msgid = message_options.msgid 1240 1241 if desc is not None: 1242 self.load_fields(desc, message_options) 1243 1244 self.callback_function = message_options.callback_function 1245 if not message_options.HasField('callback_function'): 1246 # Automatically assign a per-message callback if any field has 1247 # a special callback_datatype. 1248 for field in self.fields: 1249 if field.requires_custom_field_callback(): 1250 self.callback_function = "%s_callback" % self.name 1251 break 1252 1253 def load_fields(self, desc, message_options): 1254 '''Load field list from DescriptorProto''' 1255 1256 no_unions = [] 1257 1258 if hasattr(desc, 'oneof_decl'): 1259 for i, f in enumerate(desc.oneof_decl): 1260 oneof_options = get_nanopb_suboptions(desc, message_options, self.name + f.name) 1261 if oneof_options.no_unions: 1262 no_unions.append(i) # No union, but add fields normally 1263 elif oneof_options.type == nanopb_pb2.FT_IGNORE: 1264 pass # No union and skip fields also 1265 else: 1266 oneof = OneOf(self.name, f, oneof_options) 1267 self.oneofs[i] = oneof 1268 else: 1269 sys.stderr.write('Note: This Python protobuf library has no OneOf support\n') 1270 1271 for index, f in enumerate(desc.field): 1272 field_options = get_nanopb_suboptions(f, message_options, self.name + f.name) 1273 if field_options.type == nanopb_pb2.FT_IGNORE: 1274 continue 1275 1276 if field_options.descriptorsize > self.descriptorsize: 1277 self.descriptorsize = field_options.descriptorsize 1278 1279 field = Field(self.name, f, field_options, self.element_path + (ProtoElement.FIELD, index), self.comments) 1280 if hasattr(f, 'oneof_index') and f.HasField('oneof_index'): 1281 if hasattr(f, 'proto3_optional') and f.proto3_optional: 1282 no_unions.append(f.oneof_index) 1283 1284 if f.oneof_index in no_unions: 1285 self.fields.append(field) 1286 elif f.oneof_index in self.oneofs: 1287 self.oneofs[f.oneof_index].add_field(field) 1288 1289 if self.oneofs[f.oneof_index] not in self.fields: 1290 self.fields.append(self.oneofs[f.oneof_index]) 1291 else: 1292 self.fields.append(field) 1293 1294 if field.math_include_required: 1295 self.math_include_required = True 1296 1297 if len(desc.extension_range) > 0: 1298 field_options = get_nanopb_suboptions(desc, message_options, self.name + 'extensions') 1299 range_start = min([r.start for r in desc.extension_range]) 1300 if field_options.type != nanopb_pb2.FT_IGNORE: 1301 self.fields.append(ExtensionRange(self.name, range_start, field_options)) 1302 1303 if message_options.sort_by_tag: 1304 self.fields.sort() 1305 1306 def get_dependencies(self): 1307 '''Get list of type names that this structure refers to.''' 1308 deps = [] 1309 for f in self.fields: 1310 deps += f.get_dependencies() 1311 return deps 1312 1313 def __repr__(self): 1314 return 'Message(%s)' % self.name 1315 1316 def __str__(self): 1317 leading_comment, trailing_comment = self.get_comments() 1318 1319 result = '' 1320 if leading_comment: 1321 result = '%s\n' % leading_comment 1322 1323 result += 'typedef struct %s {' % Globals.naming_style.struct_name(self.name) 1324 if trailing_comment: 1325 result += " " + trailing_comment 1326 1327 result += '\n' 1328 1329 if not self.fields: 1330 # Empty structs are not allowed in C standard. 1331 # Therefore add a dummy field if an empty message occurs. 1332 result += ' char dummy_field;' 1333 1334 result += '\n'.join([str(f) for f in self.fields]) 1335 1336 if Globals.protoc_insertion_points: 1337 result += '\n/* @@protoc_insertion_point(struct:%s) */' % self.name 1338 1339 result += '\n}' 1340 1341 if self.packed: 1342 result += ' pb_packed' 1343 1344 result += ' %s;' % Globals.naming_style.type_name(self.name) 1345 1346 if self.packed: 1347 result = 'PB_PACKED_STRUCT_START\n' + result 1348 result += '\nPB_PACKED_STRUCT_END' 1349 1350 return result + '\n' 1351 1352 def types(self): 1353 return ''.join([f.types() for f in self.fields]) 1354 1355 def get_initializer(self, null_init): 1356 if not self.fields: 1357 return '{0}' 1358 1359 parts = [] 1360 for field in self.fields: 1361 parts.append(field.get_initializer(null_init)) 1362 return '{' + ', '.join(parts) + '}' 1363 1364 def count_required_fields(self): 1365 '''Returns number of required fields inside this message''' 1366 count = 0 1367 for f in self.fields: 1368 if not isinstance(f, OneOf): 1369 if f.rules == 'REQUIRED': 1370 count += 1 1371 return count 1372 1373 def all_fields(self): 1374 '''Iterate over all fields in this message, including nested OneOfs.''' 1375 for f in self.fields: 1376 if isinstance(f, OneOf): 1377 for f2 in f.fields: 1378 yield f2 1379 else: 1380 yield f 1381 1382 1383 def field_for_tag(self, tag): 1384 '''Given a tag number, return the Field instance.''' 1385 for field in self.all_fields(): 1386 if field.tag == tag: 1387 return field 1388 return None 1389 1390 def count_all_fields(self): 1391 '''Count the total number of fields in this message.''' 1392 count = 0 1393 for f in self.fields: 1394 if isinstance(f, OneOf): 1395 count += len(f.fields) 1396 else: 1397 count += 1 1398 return count 1399 1400 def fields_declaration(self, dependencies): 1401 '''Return X-macro declaration of all fields in this message.''' 1402 Field.macro_x_param = 'X' 1403 Field.macro_a_param = 'a' 1404 while any(field.name == Field.macro_x_param for field in self.all_fields()): 1405 Field.macro_x_param += '_' 1406 while any(field.name == Field.macro_a_param for field in self.all_fields()): 1407 Field.macro_a_param += '_' 1408 1409 # Field descriptor array must be sorted by tag number, pb_common.c relies on it. 1410 sorted_fields = list(self.all_fields()) 1411 sorted_fields.sort(key = lambda x: x.tag) 1412 1413 result = '#define %s_FIELDLIST(%s, %s) \\\n' % ( 1414 Globals.naming_style.define_name(self.name), 1415 Field.macro_x_param, 1416 Field.macro_a_param) 1417 result += ' \\\n'.join(x.fieldlist() for x in sorted_fields) 1418 result += '\n' 1419 1420 has_callbacks = bool([f for f in self.fields if f.has_callbacks()]) 1421 if has_callbacks: 1422 if self.callback_function != 'pb_default_field_callback': 1423 result += "extern bool %s(pb_istream_t *istream, pb_ostream_t *ostream, const pb_field_t *field);\n" % self.callback_function 1424 result += "#define %s_CALLBACK %s\n" % ( 1425 Globals.naming_style.define_name(self.name), 1426 self.callback_function) 1427 else: 1428 result += "#define %s_CALLBACK NULL\n" % Globals.naming_style.define_name(self.name) 1429 1430 defval = self.default_value(dependencies) 1431 if defval: 1432 hexcoded = ''.join("\\x%02x" % ord(defval[i:i+1]) for i in range(len(defval))) 1433 result += '#define %s_DEFAULT (const pb_byte_t*)"%s\\x00"\n' % ( 1434 Globals.naming_style.define_name(self.name), 1435 hexcoded) 1436 else: 1437 result += '#define %s_DEFAULT NULL\n' % Globals.naming_style.define_name(self.name) 1438 1439 for field in sorted_fields: 1440 if field.pbtype in ['MESSAGE', 'MSG_W_CB']: 1441 if field.rules == 'ONEOF': 1442 result += "#define %s_%s_%s_MSGTYPE %s\n" % ( 1443 Globals.naming_style.type_name(self.name), 1444 Globals.naming_style.var_name(field.union_name), 1445 Globals.naming_style.var_name(field.name), 1446 Globals.naming_style.type_name(field.ctype) 1447 ) 1448 else: 1449 result += "#define %s_%s_MSGTYPE %s\n" % ( 1450 Globals.naming_style.type_name(self.name), 1451 Globals.naming_style.var_name(field.name), 1452 Globals.naming_style.type_name(field.ctype) 1453 ) 1454 1455 return result 1456 1457 def enumtype_defines(self): 1458 '''Defines to allow user code to refer to enum type of a specific field''' 1459 result = '' 1460 for field in self.all_fields(): 1461 if field.pbtype in ['ENUM', "UENUM"]: 1462 if field.rules == 'ONEOF': 1463 result += "#define %s_%s_%s_ENUMTYPE %s\n" % ( 1464 Globals.naming_style.type_name(self.name), 1465 Globals.naming_style.var_name(field.union_name), 1466 Globals.naming_style.var_name(field.name), 1467 Globals.naming_style.type_name(field.ctype) 1468 ) 1469 else: 1470 result += "#define %s_%s_ENUMTYPE %s\n" % ( 1471 Globals.naming_style.type_name(self.name), 1472 Globals.naming_style.var_name(field.name), 1473 Globals.naming_style.type_name(field.ctype) 1474 ) 1475 1476 return result 1477 1478 def fields_declaration_cpp_lookup(self): 1479 result = 'template <>\n' 1480 result += 'struct MessageDescriptor<%s> {\n' % (self.name) 1481 result += ' static PB_INLINE_CONSTEXPR const pb_size_t fields_array_length = %d;\n' % (self.count_all_fields()) 1482 result += ' static inline const pb_msgdesc_t* fields() {\n' 1483 result += ' return &%s_msg;\n' % (self.name) 1484 result += ' }\n' 1485 result += '};' 1486 return result 1487 1488 def fields_definition(self, dependencies): 1489 '''Return the field descriptor definition that goes in .pb.c file.''' 1490 width = self.required_descriptor_width(dependencies) 1491 if width == 1: 1492 width = 'AUTO' 1493 1494 result = 'PB_BIND(%s, %s, %s)\n' % ( 1495 Globals.naming_style.define_name(self.name), 1496 Globals.naming_style.type_name(self.name), 1497 width) 1498 return result 1499 1500 def required_descriptor_width(self, dependencies): 1501 '''Estimate how many words are necessary for each field descriptor.''' 1502 if self.descriptorsize != nanopb_pb2.DS_AUTO: 1503 return int(self.descriptorsize) 1504 1505 if not self.fields: 1506 return 1 1507 1508 max_tag = max(field.tag for field in self.all_fields()) 1509 max_offset = self.data_size(dependencies) 1510 max_arraysize = max((field.max_count or 0) for field in self.all_fields()) 1511 max_datasize = max(field.data_size(dependencies) for field in self.all_fields()) 1512 1513 if max_arraysize > 0xFFFF: 1514 return 8 1515 elif (max_tag > 0x3FF or max_offset > 0xFFFF or 1516 max_arraysize > 0x0FFF or max_datasize > 0x0FFF): 1517 return 4 1518 elif max_tag > 0x3F or max_offset > 0xFF: 1519 return 2 1520 else: 1521 # NOTE: Macro logic in pb.h ensures that width 1 will 1522 # be raised to 2 automatically for string/submsg fields 1523 # and repeated fields. Thus only tag and offset need to 1524 # be checked. 1525 return 1 1526 1527 def data_size(self, dependencies): 1528 '''Return approximate sizeof(struct) in the compiled code.''' 1529 return sum(f.data_size(dependencies) for f in self.fields) 1530 1531 def encoded_size(self, dependencies): 1532 '''Return the maximum size that this message can take when encoded. 1533 If the size cannot be determined, returns None. 1534 ''' 1535 size = EncodedSize(0) 1536 for field in self.fields: 1537 fsize = field.encoded_size(dependencies) 1538 if fsize is None: 1539 return None 1540 size += fsize 1541 1542 return size 1543 1544 def default_value(self, dependencies): 1545 '''Generate serialized protobuf message that contains the 1546 default values for optional fields.''' 1547 1548 if not self.desc: 1549 return b'' 1550 1551 if self.desc.options.map_entry: 1552 return b'' 1553 1554 optional_only = copy.deepcopy(self.desc) 1555 1556 # Remove fields without default values 1557 # The iteration is done in reverse order to avoid remove() messing up iteration. 1558 for field in reversed(list(optional_only.field)): 1559 field.ClearField(str('extendee')) 1560 parsed_field = self.field_for_tag(field.number) 1561 if parsed_field is None or parsed_field.allocation != 'STATIC': 1562 optional_only.field.remove(field) 1563 elif (field.label == FieldD.LABEL_REPEATED or 1564 field.type == FieldD.TYPE_MESSAGE): 1565 optional_only.field.remove(field) 1566 elif hasattr(field, 'oneof_index') and field.HasField('oneof_index'): 1567 optional_only.field.remove(field) 1568 elif field.type == FieldD.TYPE_ENUM: 1569 # The partial descriptor doesn't include the enum type 1570 # so we fake it with int64. 1571 enumname = names_from_type_name(field.type_name) 1572 try: 1573 enumtype = dependencies[str(enumname)] 1574 except KeyError: 1575 raise Exception("Could not find enum type %s while generating default values for %s.\n" % (enumname, self.name) 1576 + "Try passing all source files to generator at once, or use -I option.") 1577 1578 if not isinstance(enumtype, Enum): 1579 raise Exception("Expected enum type as %s, got %s" % (enumname, repr(enumtype))) 1580 1581 if field.HasField('default_value'): 1582 defvals = [v for n,v in enumtype.values if n.parts[-1] == field.default_value] 1583 else: 1584 # If no default is specified, the default is the first value. 1585 defvals = [v for n,v in enumtype.values] 1586 if defvals and defvals[0] != 0: 1587 field.type = FieldD.TYPE_INT64 1588 field.default_value = str(defvals[0]) 1589 field.ClearField(str('type_name')) 1590 else: 1591 optional_only.field.remove(field) 1592 elif not field.HasField('default_value'): 1593 optional_only.field.remove(field) 1594 1595 if len(optional_only.field) == 0: 1596 return b'' 1597 1598 optional_only.ClearField(str('oneof_decl')) 1599 optional_only.ClearField(str('nested_type')) 1600 optional_only.ClearField(str('extension')) 1601 optional_only.ClearField(str('enum_type')) 1602 optional_only.name += str(id(self)) 1603 1604 desc = google.protobuf.descriptor.MakeDescriptor(optional_only) 1605 msg = reflection.MakeClass(desc)() 1606 1607 for field in optional_only.field: 1608 if field.type == FieldD.TYPE_STRING: 1609 setattr(msg, field.name, field.default_value) 1610 elif field.type == FieldD.TYPE_BYTES: 1611 setattr(msg, field.name, codecs.escape_decode(field.default_value)[0]) 1612 elif field.type in [FieldD.TYPE_FLOAT, FieldD.TYPE_DOUBLE]: 1613 setattr(msg, field.name, float(field.default_value)) 1614 elif field.type == FieldD.TYPE_BOOL: 1615 setattr(msg, field.name, field.default_value == 'true') 1616 else: 1617 setattr(msg, field.name, int(field.default_value)) 1618 1619 return msg.SerializeToString() 1620 1621 1622# --------------------------------------------------------------------------- 1623# Processing of entire .proto files 1624# --------------------------------------------------------------------------- 1625 1626def iterate_messages(desc, flatten = False, names = Names(), comment_path = ()): 1627 '''Recursively find all messages. For each, yield name, DescriptorProto, comment_path.''' 1628 if hasattr(desc, 'message_type'): 1629 submsgs = desc.message_type 1630 comment_path += (ProtoElement.MESSAGE,) 1631 else: 1632 submsgs = desc.nested_type 1633 comment_path += (ProtoElement.NESTED_TYPE,) 1634 1635 for idx, submsg in enumerate(submsgs): 1636 sub_names = names + submsg.name 1637 sub_path = comment_path + (idx,) 1638 if flatten: 1639 yield Names(submsg.name), submsg, sub_path 1640 else: 1641 yield sub_names, submsg, sub_path 1642 1643 for x in iterate_messages(submsg, flatten, sub_names, sub_path): 1644 yield x 1645 1646def iterate_extensions(desc, flatten = False, names = Names()): 1647 '''Recursively find all extensions. 1648 For each, yield name, FieldDescriptorProto. 1649 ''' 1650 for extension in desc.extension: 1651 yield names, extension 1652 1653 for subname, subdesc, comment_path in iterate_messages(desc, flatten, names): 1654 for extension in subdesc.extension: 1655 yield subname, extension 1656 1657def sort_dependencies(messages): 1658 '''Sort a list of Messages based on dependencies.''' 1659 1660 # Construct first level list of dependencies 1661 dependencies = {} 1662 for message in messages: 1663 dependencies[str(message.name)] = set(message.get_dependencies()) 1664 1665 # Emit messages after all their dependencies have been processed 1666 remaining = list(messages) 1667 remainset = set(str(m.name) for m in remaining) 1668 while remaining: 1669 for candidate in remaining: 1670 if not remainset.intersection(dependencies[str(candidate.name)]): 1671 remaining.remove(candidate) 1672 remainset.remove(str(candidate.name)) 1673 yield candidate 1674 break 1675 else: 1676 sys.stderr.write("Circular dependency in messages: " + ', '.join(remainset) + " (consider changing to FT_POINTER or FT_CALLBACK)\n") 1677 candidate = remaining.pop(0) 1678 remainset.remove(str(candidate.name)) 1679 yield candidate 1680 1681def make_identifier(headername): 1682 '''Make #ifndef identifier that contains uppercase A-Z and digits 0-9''' 1683 result = "" 1684 for c in headername.upper(): 1685 if c.isalnum(): 1686 result += c 1687 else: 1688 result += '_' 1689 return result 1690 1691class MangleNames: 1692 '''Handles conversion of type names according to mangle_names option: 1693 M_NONE = 0; // Default, no typename mangling 1694 M_STRIP_PACKAGE = 1; // Strip current package name 1695 M_FLATTEN = 2; // Only use last path component 1696 M_PACKAGE_INITIALS = 3; // Replace the package name by the initials 1697 ''' 1698 def __init__(self, fdesc, file_options): 1699 self.file_options = file_options 1700 self.mangle_names = file_options.mangle_names 1701 self.flatten = (self.mangle_names == nanopb_pb2.M_FLATTEN) 1702 self.strip_prefix = None 1703 self.replacement_prefix = None 1704 self.name_mapping = {} 1705 self.reverse_name_mapping = {} 1706 self.canonical_base = Names(fdesc.package.split('.')) 1707 1708 if self.mangle_names == nanopb_pb2.M_STRIP_PACKAGE: 1709 self.strip_prefix = "." + fdesc.package 1710 elif self.mangle_names == nanopb_pb2.M_PACKAGE_INITIALS: 1711 self.strip_prefix = "." + fdesc.package 1712 self.replacement_prefix = "" 1713 for part in fdesc.package.split("."): 1714 self.replacement_prefix += part[0] 1715 elif file_options.package: 1716 self.strip_prefix = "." + fdesc.package 1717 self.replacement_prefix = file_options.package 1718 1719 if self.strip_prefix == '.': 1720 self.strip_prefix = '' 1721 1722 if self.replacement_prefix is not None: 1723 self.base_name = Names(self.replacement_prefix.split('.')) 1724 elif fdesc.package: 1725 self.base_name = Names(fdesc.package.split('.')) 1726 else: 1727 self.base_name = Names() 1728 1729 def create_name(self, names): 1730 '''Create name for a new message / enum. 1731 Argument can be either string or Names instance. 1732 ''' 1733 if str(names) not in self.name_mapping: 1734 if self.mangle_names in (nanopb_pb2.M_NONE, nanopb_pb2.M_PACKAGE_INITIALS): 1735 new_name = self.base_name + names 1736 elif self.mangle_names == nanopb_pb2.M_STRIP_PACKAGE: 1737 new_name = Names(names) 1738 elif isinstance(names, Names): 1739 new_name = Names(names.parts[-1]) 1740 else: 1741 new_name = Names(names) 1742 1743 if str(new_name) in self.reverse_name_mapping: 1744 sys.stderr.write("Warning: Duplicate name with mangle_names=%s: %s and %s map to %s\n" % 1745 (self.mangle_names, self.reverse_name_mapping[str(new_name)], names, new_name)) 1746 1747 self.name_mapping[str(names)] = new_name 1748 self.reverse_name_mapping[str(new_name)] = self.canonical_base + names 1749 1750 return self.name_mapping[str(names)] 1751 1752 def mangle_field_typename(self, typename): 1753 '''Mangle type name for a submessage / enum crossreference. 1754 Argument is a string. 1755 ''' 1756 if self.mangle_names == nanopb_pb2.M_FLATTEN: 1757 return "." + typename.split(".")[-1] 1758 1759 if self.strip_prefix is not None and typename.startswith(self.strip_prefix): 1760 if self.replacement_prefix is not None: 1761 return "." + self.replacement_prefix + typename[len(self.strip_prefix):] 1762 else: 1763 return typename[len(self.strip_prefix):] 1764 1765 if self.file_options.package: 1766 return "." + self.replacement_prefix + typename 1767 1768 return typename 1769 1770 def unmangle(self, names): 1771 return self.reverse_name_mapping.get(str(names), names) 1772 1773class ProtoFile: 1774 def __init__(self, fdesc, file_options): 1775 '''Takes a FileDescriptorProto and parses it.''' 1776 self.fdesc = fdesc 1777 self.file_options = file_options 1778 self.dependencies = {} 1779 self.math_include_required = False 1780 self.parse() 1781 for message in self.messages: 1782 if message.math_include_required: 1783 self.math_include_required = True 1784 break 1785 1786 # Some of types used in this file probably come from the file itself. 1787 # Thus it has implicit dependency on itself. 1788 self.add_dependency(self) 1789 1790 def parse(self): 1791 self.enums = [] 1792 self.messages = [] 1793 self.extensions = [] 1794 self.manglenames = MangleNames(self.fdesc, self.file_options) 1795 1796 # process source code comment locations 1797 # ignores any locations that do not contain any comment information 1798 self.comment_locations = { 1799 tuple(location.path): location 1800 for location in self.fdesc.source_code_info.location 1801 if location.leading_comments or location.leading_detached_comments or location.trailing_comments 1802 } 1803 1804 for index, enum in enumerate(self.fdesc.enum_type): 1805 name = self.manglenames.create_name(enum.name) 1806 enum_options = get_nanopb_suboptions(enum, self.file_options, name) 1807 enum_path = (ProtoElement.ENUM, index) 1808 self.enums.append(Enum(name, enum, enum_options, enum_path, self.comment_locations)) 1809 1810 for names, message, comment_path in iterate_messages(self.fdesc, self.manglenames.flatten): 1811 name = self.manglenames.create_name(names) 1812 message_options = get_nanopb_suboptions(message, self.file_options, name) 1813 1814 if message_options.skip_message: 1815 continue 1816 1817 message = copy.deepcopy(message) 1818 for field in message.field: 1819 if field.type in (FieldD.TYPE_MESSAGE, FieldD.TYPE_ENUM): 1820 field.type_name = self.manglenames.mangle_field_typename(field.type_name) 1821 1822 self.messages.append(Message(name, message, message_options, comment_path, self.comment_locations)) 1823 for index, enum in enumerate(message.enum_type): 1824 name = self.manglenames.create_name(names + enum.name) 1825 enum_options = get_nanopb_suboptions(enum, message_options, name) 1826 enum_path = comment_path + (ProtoElement.NESTED_ENUM, index) 1827 self.enums.append(Enum(name, enum, enum_options, enum_path, self.comment_locations)) 1828 1829 for names, extension in iterate_extensions(self.fdesc, self.manglenames.flatten): 1830 name = self.manglenames.create_name(names + extension.name) 1831 field_options = get_nanopb_suboptions(extension, self.file_options, name) 1832 1833 extension = copy.deepcopy(extension) 1834 if extension.type in (FieldD.TYPE_MESSAGE, FieldD.TYPE_ENUM): 1835 extension.type_name = self.manglenames.mangle_field_typename(extension.type_name) 1836 1837 if field_options.type != nanopb_pb2.FT_IGNORE: 1838 self.extensions.append(ExtensionField(name, extension, field_options)) 1839 1840 def add_dependency(self, other): 1841 for enum in other.enums: 1842 self.dependencies[str(enum.names)] = enum 1843 self.dependencies[str(other.manglenames.unmangle(enum.names))] = enum 1844 enum.protofile = other 1845 1846 for msg in other.messages: 1847 self.dependencies[str(msg.name)] = msg 1848 self.dependencies[str(other.manglenames.unmangle(msg.name))] = msg 1849 msg.protofile = other 1850 1851 # Fix field default values where enum short names are used. 1852 for enum in other.enums: 1853 if not enum.options.long_names: 1854 for message in self.messages: 1855 for field in message.all_fields(): 1856 if field.default in enum.value_longnames: 1857 idx = enum.value_longnames.index(field.default) 1858 field.default = enum.values[idx][0] 1859 1860 # Fix field data types where enums have negative values. 1861 for enum in other.enums: 1862 if not enum.has_negative(): 1863 for message in self.messages: 1864 for field in message.all_fields(): 1865 if field.pbtype == 'ENUM' and field.ctype == enum.names: 1866 field.pbtype = 'UENUM' 1867 1868 def generate_header(self, includes, headername, options): 1869 '''Generate content for a header file. 1870 Generates strings, which should be concatenated and stored to file. 1871 ''' 1872 1873 yield '/* Automatically generated nanopb header */\n' 1874 if options.notimestamp: 1875 yield '/* Generated by %s */\n\n' % (nanopb_version) 1876 else: 1877 yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime()) 1878 1879 if self.fdesc.package: 1880 symbol = make_identifier(self.fdesc.package + '_' + headername) 1881 else: 1882 symbol = make_identifier(headername) 1883 yield '#ifndef PB_%s_INCLUDED\n' % symbol 1884 yield '#define PB_%s_INCLUDED\n' % symbol 1885 if self.math_include_required: 1886 yield '#include <math.h>\n' 1887 try: 1888 yield options.libformat % ('pb.h') 1889 except TypeError: 1890 # no %s specified - use whatever was passed in as options.libformat 1891 yield options.libformat 1892 yield '\n' 1893 1894 for incfile in self.file_options.include: 1895 # allow including system headers 1896 if (incfile.startswith('<')): 1897 yield '#include %s\n' % incfile 1898 else: 1899 yield options.genformat % incfile 1900 yield '\n' 1901 1902 for incfile in includes: 1903 noext = os.path.splitext(incfile)[0] 1904 yield options.genformat % (noext + options.extension + options.header_extension) 1905 yield '\n' 1906 1907 if Globals.protoc_insertion_points: 1908 yield '/* @@protoc_insertion_point(includes) */\n' 1909 1910 yield '\n' 1911 1912 yield '#if PB_PROTO_HEADER_VERSION != 40\n' 1913 yield '#error Regenerate this file with the current version of nanopb generator.\n' 1914 yield '#endif\n' 1915 yield '\n' 1916 1917 if self.enums: 1918 yield '/* Enum definitions */\n' 1919 for enum in self.enums: 1920 yield str(enum) + '\n\n' 1921 1922 if self.messages: 1923 yield '/* Struct definitions */\n' 1924 for msg in sort_dependencies(self.messages): 1925 yield msg.types() 1926 yield str(msg) + '\n' 1927 yield '\n' 1928 1929 if self.extensions: 1930 yield '/* Extensions */\n' 1931 for extension in self.extensions: 1932 yield extension.extension_decl() 1933 yield '\n' 1934 1935 yield '#ifdef __cplusplus\n' 1936 yield 'extern "C" {\n' 1937 yield '#endif\n\n' 1938 1939 if self.enums: 1940 yield '/* Helper constants for enums */\n' 1941 for enum in self.enums: 1942 yield enum.auxiliary_defines() + '\n' 1943 1944 for msg in self.messages: 1945 yield msg.enumtype_defines() + '\n' 1946 yield '\n' 1947 1948 if self.messages: 1949 yield '/* Initializer values for message structs */\n' 1950 for msg in self.messages: 1951 identifier = Globals.naming_style.define_name('%s_init_default' % msg.name) 1952 yield '#define %-40s %s\n' % (identifier, msg.get_initializer(False)) 1953 for msg in self.messages: 1954 identifier = Globals.naming_style.define_name('%s_init_zero' % msg.name) 1955 yield '#define %-40s %s\n' % (identifier, msg.get_initializer(True)) 1956 yield '\n' 1957 1958 yield '/* Field tags (for use in manual encoding/decoding) */\n' 1959 for msg in sort_dependencies(self.messages): 1960 for field in msg.fields: 1961 yield field.tags() 1962 for extension in self.extensions: 1963 yield extension.tags() 1964 yield '\n' 1965 1966 yield '/* Struct field encoding specification for nanopb */\n' 1967 for msg in self.messages: 1968 yield msg.fields_declaration(self.dependencies) + '\n' 1969 for msg in self.messages: 1970 yield 'extern const pb_msgdesc_t %s_msg;\n' % Globals.naming_style.type_name(msg.name) 1971 yield '\n' 1972 1973 yield '/* Defines for backwards compatibility with code written before nanopb-0.4.0 */\n' 1974 for msg in self.messages: 1975 yield '#define %s &%s_msg\n' % ( 1976 Globals.naming_style.define_name('%s_fields' % msg.name), 1977 Globals.naming_style.type_name(msg.name)) 1978 yield '\n' 1979 1980 yield '/* Maximum encoded size of messages (where known) */\n' 1981 messagesizes = [] 1982 for msg in self.messages: 1983 identifier = '%s_size' % msg.name 1984 messagesizes.append((identifier, msg.encoded_size(self.dependencies))) 1985 1986 # If we require a symbol from another file, put a preprocessor if statement 1987 # around it to prevent compilation errors if the symbol is not actually available. 1988 local_defines = [identifier for identifier, msize in messagesizes if msize is not None] 1989 1990 # emit size_unions, if any 1991 oneof_sizes = [] 1992 for msg in self.messages: 1993 for f in msg.fields: 1994 if isinstance(f, OneOf): 1995 msize = f.encoded_size(self.dependencies) 1996 if msize is not None: 1997 oneof_sizes.append(msize) 1998 for msize in oneof_sizes: 1999 guard = msize.get_cpp_guard(local_defines) 2000 if guard: 2001 yield guard 2002 yield msize.get_declarations() 2003 if guard: 2004 yield '#endif\n' 2005 2006 guards = {} 2007 for identifier, msize in messagesizes: 2008 if msize is not None: 2009 cpp_guard = msize.get_cpp_guard(local_defines) 2010 if cpp_guard not in guards: 2011 guards[cpp_guard] = set() 2012 guards[cpp_guard].add('#define %-40s %s' % ( 2013 Globals.naming_style.define_name(identifier), msize)) 2014 else: 2015 yield '/* %s depends on runtime parameters */\n' % identifier 2016 for guard, values in guards.items(): 2017 if guard: 2018 yield guard 2019 for v in sorted(values): 2020 yield v 2021 yield '\n' 2022 if guard: 2023 yield '#endif\n' 2024 yield '\n' 2025 2026 if [msg for msg in self.messages if hasattr(msg,'msgid')]: 2027 yield '/* Message IDs (where set with "msgid" option) */\n' 2028 for msg in self.messages: 2029 if hasattr(msg,'msgid'): 2030 yield '#define PB_MSG_%d %s\n' % (msg.msgid, msg.name) 2031 yield '\n' 2032 2033 symbol = make_identifier(headername.split('.')[0]) 2034 yield '#define %s_MESSAGES \\\n' % symbol 2035 2036 for msg in self.messages: 2037 m = "-1" 2038 msize = msg.encoded_size(self.dependencies) 2039 if msize is not None: 2040 m = msize 2041 if hasattr(msg,'msgid'): 2042 yield '\tPB_MSG(%d,%s,%s) \\\n' % (msg.msgid, m, msg.name) 2043 yield '\n' 2044 2045 for msg in self.messages: 2046 if hasattr(msg,'msgid'): 2047 yield '#define %s_msgid %d\n' % (msg.name, msg.msgid) 2048 yield '\n' 2049 2050 # Check if there is any name mangling active 2051 pairs = [x for x in self.manglenames.reverse_name_mapping.items() if str(x[0]) != str(x[1])] 2052 if pairs: 2053 yield '/* Mapping from canonical names (mangle_names or overridden package name) */\n' 2054 for shortname, longname in pairs: 2055 yield '#define %s %s\n' % (longname, shortname) 2056 yield '\n' 2057 2058 yield '#ifdef __cplusplus\n' 2059 yield '} /* extern "C" */\n' 2060 yield '#endif\n' 2061 2062 if options.cpp_descriptors: 2063 yield '\n' 2064 yield '#ifdef __cplusplus\n' 2065 yield '/* Message descriptors for nanopb */\n' 2066 yield 'namespace nanopb {\n' 2067 for msg in self.messages: 2068 yield msg.fields_declaration_cpp_lookup() + '\n' 2069 yield '} // namespace nanopb\n' 2070 yield '\n' 2071 yield '#endif /* __cplusplus */\n' 2072 yield '\n' 2073 2074 if Globals.protoc_insertion_points: 2075 yield '/* @@protoc_insertion_point(eof) */\n' 2076 2077 # End of header 2078 yield '\n#endif\n' 2079 2080 def generate_source(self, headername, options): 2081 '''Generate content for a source file.''' 2082 2083 yield '/* Automatically generated nanopb constant definitions */\n' 2084 if options.notimestamp: 2085 yield '/* Generated by %s */\n\n' % (nanopb_version) 2086 else: 2087 yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime()) 2088 yield options.genformat % (headername) 2089 yield '\n' 2090 2091 if Globals.protoc_insertion_points: 2092 yield '/* @@protoc_insertion_point(includes) */\n' 2093 2094 yield '#if PB_PROTO_HEADER_VERSION != 40\n' 2095 yield '#error Regenerate this file with the current version of nanopb generator.\n' 2096 yield '#endif\n' 2097 yield '\n' 2098 2099 # Check if any messages exceed the 64 kB limit of 16-bit pb_size_t 2100 exceeds_64kB = [] 2101 for msg in self.messages: 2102 size = msg.data_size(self.dependencies) 2103 if size >= 65536: 2104 exceeds_64kB.append(str(msg.name)) 2105 2106 if exceeds_64kB: 2107 yield '\n/* The following messages exceed 64kB in size: ' + ', '.join(exceeds_64kB) + ' */\n' 2108 yield '\n/* The PB_FIELD_32BIT compilation option must be defined to support messages that exceed 64 kB in size. */\n' 2109 yield '#ifndef PB_FIELD_32BIT\n' 2110 yield '#error Enable PB_FIELD_32BIT to support messages exceeding 64kB in size: ' + ', '.join(exceeds_64kB) + '\n' 2111 yield '#endif\n' 2112 2113 # Generate the message field definitions (PB_BIND() call) 2114 for msg in self.messages: 2115 yield msg.fields_definition(self.dependencies) + '\n\n' 2116 2117 # Generate pb_extension_type_t definitions if extensions are used in proto file 2118 for ext in self.extensions: 2119 yield ext.extension_def(self.dependencies) + '\n' 2120 2121 # Generate enum_name function if enum_to_string option is defined 2122 for enum in self.enums: 2123 yield enum.enum_to_string_definition() + '\n' 2124 2125 # Add checks for numeric limits 2126 if self.messages: 2127 largest_msg = max(self.messages, key = lambda m: m.count_required_fields()) 2128 largest_count = largest_msg.count_required_fields() 2129 if largest_count > 64: 2130 yield '\n/* Check that missing required fields will be properly detected */\n' 2131 yield '#if PB_MAX_REQUIRED_FIELDS < %d\n' % largest_count 2132 yield '#error Properly detecting missing required fields in %s requires \\\n' % largest_msg.name 2133 yield ' setting PB_MAX_REQUIRED_FIELDS to %d or more.\n' % largest_count 2134 yield '#endif\n' 2135 2136 # Add check for sizeof(double) 2137 has_double = False 2138 for msg in self.messages: 2139 for field in msg.all_fields(): 2140 if field.ctype == 'double': 2141 has_double = True 2142 2143 if has_double: 2144 yield '\n' 2145 yield '#ifndef PB_CONVERT_DOUBLE_FLOAT\n' 2146 yield '/* On some platforms (such as AVR), double is really float.\n' 2147 yield ' * To be able to encode/decode double on these platforms, you need.\n' 2148 yield ' * to define PB_CONVERT_DOUBLE_FLOAT in pb.h or compiler command line.\n' 2149 yield ' */\n' 2150 yield 'PB_STATIC_ASSERT(sizeof(double) == 8, DOUBLE_MUST_BE_8_BYTES)\n' 2151 yield '#endif\n' 2152 2153 yield '\n' 2154 2155 if Globals.protoc_insertion_points: 2156 yield '/* @@protoc_insertion_point(eof) */\n' 2157 2158# --------------------------------------------------------------------------- 2159# Options parsing for the .proto files 2160# --------------------------------------------------------------------------- 2161 2162from fnmatch import fnmatchcase 2163 2164def read_options_file(infile): 2165 '''Parse a separate options file to list: 2166 [(namemask, options), ...] 2167 ''' 2168 results = [] 2169 data = infile.read() 2170 data = re.sub(r'/\*.*?\*/', '', data, flags = re.MULTILINE) 2171 data = re.sub(r'//.*?$', '', data, flags = re.MULTILINE) 2172 data = re.sub(r'#.*?$', '', data, flags = re.MULTILINE) 2173 for i, line in enumerate(data.split('\n')): 2174 line = line.strip() 2175 if not line: 2176 continue 2177 2178 parts = line.split(None, 1) 2179 2180 if len(parts) < 2: 2181 sys.stderr.write("%s:%d: " % (infile.name, i + 1) + 2182 "Option lines should have space between field name and options. " + 2183 "Skipping line: '%s'\n" % line) 2184 sys.exit(1) 2185 2186 opts = nanopb_pb2.NanoPBOptions() 2187 2188 try: 2189 text_format.Merge(parts[1], opts) 2190 except Exception as e: 2191 sys.stderr.write("%s:%d: " % (infile.name, i + 1) + 2192 "Unparsable option line: '%s'. " % line + 2193 "Error: %s\n" % str(e)) 2194 sys.exit(1) 2195 results.append((parts[0], opts)) 2196 2197 return results 2198 2199def get_nanopb_suboptions(subdesc, options, name): 2200 '''Get copy of options, and merge information from subdesc.''' 2201 new_options = nanopb_pb2.NanoPBOptions() 2202 new_options.CopyFrom(options) 2203 2204 if hasattr(subdesc, 'syntax') and subdesc.syntax == "proto3": 2205 new_options.proto3 = True 2206 2207 # Handle options defined in a separate file 2208 dotname = '.'.join(name.parts) 2209 for namemask, options in Globals.separate_options: 2210 if fnmatchcase(dotname, namemask): 2211 Globals.matched_namemasks.add(namemask) 2212 new_options.MergeFrom(options) 2213 2214 # Handle options defined in .proto 2215 if isinstance(subdesc.options, descriptor.FieldOptions): 2216 ext_type = nanopb_pb2.nanopb 2217 elif isinstance(subdesc.options, descriptor.FileOptions): 2218 ext_type = nanopb_pb2.nanopb_fileopt 2219 elif isinstance(subdesc.options, descriptor.MessageOptions): 2220 ext_type = nanopb_pb2.nanopb_msgopt 2221 elif isinstance(subdesc.options, descriptor.EnumOptions): 2222 ext_type = nanopb_pb2.nanopb_enumopt 2223 else: 2224 raise Exception("Unknown options type") 2225 2226 if subdesc.options.HasExtension(ext_type): 2227 ext = subdesc.options.Extensions[ext_type] 2228 new_options.MergeFrom(ext) 2229 2230 if Globals.verbose_options: 2231 sys.stderr.write("Options for " + dotname + ": ") 2232 sys.stderr.write(text_format.MessageToString(new_options) + "\n") 2233 2234 return new_options 2235 2236 2237# --------------------------------------------------------------------------- 2238# Command line interface 2239# --------------------------------------------------------------------------- 2240 2241import sys 2242import os.path 2243from optparse import OptionParser 2244 2245optparser = OptionParser( 2246 usage = "Usage: nanopb_generator.py [options] file.pb ...", 2247 epilog = "Compile file.pb from file.proto by: 'protoc -ofile.pb file.proto'. " + 2248 "Output will be written to file.pb.h and file.pb.c.") 2249optparser.add_option("--version", dest="version", action="store_true", 2250 help="Show version info and exit") 2251optparser.add_option("-x", dest="exclude", metavar="FILE", action="append", default=[], 2252 help="Exclude file from generated #include list.") 2253optparser.add_option("-e", "--extension", dest="extension", metavar="EXTENSION", default=".pb", 2254 help="Set extension to use instead of '.pb' for generated files. [default: %default]") 2255optparser.add_option("-H", "--header-extension", dest="header_extension", metavar="EXTENSION", default=".h", 2256 help="Set extension to use for generated header files. [default: %default]") 2257optparser.add_option("-S", "--source-extension", dest="source_extension", metavar="EXTENSION", default=".c", 2258 help="Set extension to use for generated source files. [default: %default]") 2259optparser.add_option("-f", "--options-file", dest="options_file", metavar="FILE", default="%s.options", 2260 help="Set name of a separate generator options file.") 2261optparser.add_option("-I", "--options-path", "--proto-path", dest="options_path", metavar="DIR", 2262 action="append", default = [], 2263 help="Search path for .options and .proto files. Also determines relative paths for output directory structure.") 2264optparser.add_option("--error-on-unmatched", dest="error_on_unmatched", action="store_true", default=False, 2265 help ="Stop generation if there are unmatched fields in options file") 2266optparser.add_option("--no-error-on-unmatched", dest="error_on_unmatched", action="store_false", default=False, 2267 help ="Continue generation if there are unmatched fields in options file (default)") 2268optparser.add_option("-D", "--output-dir", dest="output_dir", 2269 metavar="OUTPUTDIR", default=None, 2270 help="Output directory of .pb.h and .pb.c files") 2271optparser.add_option("-Q", "--generated-include-format", dest="genformat", 2272 metavar="FORMAT", default='#include "%s"', 2273 help="Set format string to use for including other .pb.h files. Value can be 'quote', 'bracket' or a format string. [default: %default]") 2274optparser.add_option("-L", "--library-include-format", dest="libformat", 2275 metavar="FORMAT", default='#include <%s>', 2276 help="Set format string to use for including the nanopb pb.h header. Value can be 'quote', 'bracket' or a format string. [default: %default]") 2277optparser.add_option("--strip-path", dest="strip_path", action="store_true", default=False, 2278 help="Strip directory path from #included .pb.h file name") 2279optparser.add_option("--no-strip-path", dest="strip_path", action="store_false", 2280 help="Opposite of --strip-path (default since 0.4.0)") 2281optparser.add_option("--cpp-descriptors", action="store_true", 2282 help="Generate C++ descriptors to lookup by type (e.g. pb_field_t for a message)") 2283optparser.add_option("-T", "--no-timestamp", dest="notimestamp", action="store_true", default=True, 2284 help="Don't add timestamp to .pb.h and .pb.c preambles (default since 0.4.0)") 2285optparser.add_option("-t", "--timestamp", dest="notimestamp", action="store_false", default=True, 2286 help="Add timestamp to .pb.h and .pb.c preambles") 2287optparser.add_option("-q", "--quiet", dest="quiet", action="store_true", default=False, 2288 help="Don't print anything except errors.") 2289optparser.add_option("-v", "--verbose", dest="verbose", action="store_true", default=False, 2290 help="Print more information.") 2291optparser.add_option("-s", dest="settings", metavar="OPTION:VALUE", action="append", default=[], 2292 help="Set generator option (max_size, max_count etc.).") 2293optparser.add_option("--protoc-opt", dest="protoc_opts", action="append", default = [], metavar="OPTION", 2294 help="Pass an option to protoc when compiling .proto files") 2295optparser.add_option("--protoc-insertion-points", dest="protoc_insertion_points", action="store_true", default=False, 2296 help="Include insertion point comments in output for use by custom protoc plugins") 2297optparser.add_option("-C", "--c-style", dest="c_style", action="store_true", default=False, 2298 help="Use C naming convention.") 2299 2300def process_cmdline(args, is_plugin): 2301 '''Process command line options. Returns list of options, filenames.''' 2302 2303 options, filenames = optparser.parse_args(args) 2304 2305 if options.version: 2306 if is_plugin: 2307 sys.stderr.write('%s\n' % (nanopb_version)) 2308 else: 2309 print(nanopb_version) 2310 sys.exit(0) 2311 2312 if not filenames and not is_plugin: 2313 optparser.print_help() 2314 sys.exit(1) 2315 2316 if options.quiet: 2317 options.verbose = False 2318 2319 include_formats = {'quote': '#include "%s"', 'bracket': '#include <%s>'} 2320 options.libformat = include_formats.get(options.libformat, options.libformat) 2321 options.genformat = include_formats.get(options.genformat, options.genformat) 2322 2323 if options.c_style: 2324 Globals.naming_style = NamingStyleC() 2325 2326 Globals.verbose_options = options.verbose 2327 2328 if options.verbose: 2329 sys.stderr.write("Nanopb version %s\n" % nanopb_version) 2330 sys.stderr.write('Google Python protobuf library imported from %s, version %s\n' 2331 % (google.protobuf.__file__, google.protobuf.__version__)) 2332 2333 return options, filenames 2334 2335 2336def parse_file(filename, fdesc, options): 2337 '''Parse a single file. Returns a ProtoFile instance.''' 2338 toplevel_options = nanopb_pb2.NanoPBOptions() 2339 for s in options.settings: 2340 if ':' not in s and '=' in s: 2341 s = s.replace('=', ':') 2342 text_format.Merge(s, toplevel_options) 2343 2344 if not fdesc: 2345 data = open(filename, 'rb').read() 2346 fdesc = descriptor.FileDescriptorSet.FromString(data).file[0] 2347 2348 # Check if there is a separate .options file 2349 had_abspath = False 2350 try: 2351 optfilename = options.options_file % os.path.splitext(filename)[0] 2352 except TypeError: 2353 # No %s specified, use the filename as-is 2354 optfilename = options.options_file 2355 had_abspath = True 2356 2357 paths = ['.'] + options.options_path 2358 for p in paths: 2359 if os.path.isfile(os.path.join(p, optfilename)): 2360 optfilename = os.path.join(p, optfilename) 2361 if options.verbose: 2362 sys.stderr.write('Reading options from ' + optfilename + '\n') 2363 Globals.separate_options = read_options_file(open(optfilename, 'r', encoding = 'utf-8')) 2364 break 2365 else: 2366 # If we are given a full filename and it does not exist, give an error. 2367 # However, don't give error when we automatically look for .options file 2368 # with the same name as .proto. 2369 if options.verbose or had_abspath: 2370 sys.stderr.write('Options file not found: ' + optfilename + '\n') 2371 Globals.separate_options = [] 2372 2373 Globals.matched_namemasks = set() 2374 Globals.protoc_insertion_points = options.protoc_insertion_points 2375 2376 # Parse the file 2377 file_options = get_nanopb_suboptions(fdesc, toplevel_options, Names([filename])) 2378 f = ProtoFile(fdesc, file_options) 2379 f.optfilename = optfilename 2380 2381 return f 2382 2383def process_file(filename, fdesc, options, other_files = {}): 2384 '''Process a single file. 2385 filename: The full path to the .proto or .pb source file, as string. 2386 fdesc: The loaded FileDescriptorSet, or None to read from the input file. 2387 options: Command line options as they come from OptionsParser. 2388 2389 Returns a dict: 2390 {'headername': Name of header file, 2391 'headerdata': Data for the .h header file, 2392 'sourcename': Name of the source code file, 2393 'sourcedata': Data for the .c source code file 2394 } 2395 ''' 2396 f = parse_file(filename, fdesc, options) 2397 2398 # Check the list of dependencies, and if they are available in other_files, 2399 # add them to be considered for import resolving. Recursively add any files 2400 # imported by the dependencies. 2401 deps = list(f.fdesc.dependency) 2402 while deps: 2403 dep = deps.pop(0) 2404 if dep in other_files: 2405 f.add_dependency(other_files[dep]) 2406 deps += list(other_files[dep].fdesc.dependency) 2407 2408 # Decide the file names 2409 noext = os.path.splitext(filename)[0] 2410 headername = noext + options.extension + options.header_extension 2411 sourcename = noext + options.extension + options.source_extension 2412 2413 if options.strip_path: 2414 headerbasename = os.path.basename(headername) 2415 else: 2416 headerbasename = headername 2417 2418 # List of .proto files that should not be included in the C header file 2419 # even if they are mentioned in the source .proto. 2420 excludes = ['nanopb.proto', 'google/protobuf/descriptor.proto'] + options.exclude + list(f.file_options.exclude) 2421 includes = [d for d in f.fdesc.dependency if d not in excludes] 2422 2423 headerdata = ''.join(f.generate_header(includes, headerbasename, options)) 2424 sourcedata = ''.join(f.generate_source(headerbasename, options)) 2425 2426 # Check if there were any lines in .options that did not match a member 2427 unmatched = [n for n,o in Globals.separate_options if n not in Globals.matched_namemasks] 2428 if unmatched: 2429 if options.error_on_unmatched: 2430 raise Exception("Following patterns in " + f.optfilename + " did not match any fields: " 2431 + ', '.join(unmatched)); 2432 elif not options.quiet: 2433 sys.stderr.write("Following patterns in " + f.optfilename + " did not match any fields: " 2434 + ', '.join(unmatched) + "\n") 2435 2436 if not Globals.verbose_options: 2437 sys.stderr.write("Use protoc --nanopb-out=-v:. to see a list of the field names.\n") 2438 2439 return {'headername': headername, 'headerdata': headerdata, 2440 'sourcename': sourcename, 'sourcedata': sourcedata} 2441 2442def main_cli(): 2443 '''Main function when invoked directly from the command line.''' 2444 2445 options, filenames = process_cmdline(sys.argv[1:], is_plugin = False) 2446 2447 if options.output_dir and not os.path.exists(options.output_dir): 2448 optparser.print_help() 2449 sys.stderr.write("\noutput_dir does not exist: %s\n" % options.output_dir) 2450 sys.exit(1) 2451 2452 # Load .pb files into memory and compile any .proto files. 2453 include_path = ['-I%s' % p for p in options.options_path] 2454 all_fdescs = {} 2455 out_fdescs = {} 2456 for filename in filenames: 2457 if filename.endswith(".proto"): 2458 with TemporaryDirectory() as tmpdir: 2459 tmpname = os.path.join(tmpdir, os.path.basename(filename) + ".pb") 2460 args = ["protoc"] + include_path 2461 args += options.protoc_opts 2462 args += ['--include_imports', '--include_source_info', '-o' + tmpname, filename] 2463 status = invoke_protoc(args) 2464 if status != 0: sys.exit(status) 2465 data = open(tmpname, 'rb').read() 2466 else: 2467 data = open(filename, 'rb').read() 2468 2469 fdescs = descriptor.FileDescriptorSet.FromString(data).file 2470 last_fdesc = fdescs[-1] 2471 2472 for fdesc in fdescs: 2473 all_fdescs[fdesc.name] = fdesc 2474 2475 out_fdescs[last_fdesc.name] = last_fdesc 2476 2477 # Process any include files first, in order to have them 2478 # available as dependencies 2479 other_files = {} 2480 for fdesc in all_fdescs.values(): 2481 other_files[fdesc.name] = parse_file(fdesc.name, fdesc, options) 2482 2483 # Then generate the headers / sources 2484 for fdesc in out_fdescs.values(): 2485 results = process_file(fdesc.name, fdesc, options, other_files) 2486 2487 base_dir = options.output_dir or '' 2488 to_write = [ 2489 (os.path.join(base_dir, results['headername']), results['headerdata']), 2490 (os.path.join(base_dir, results['sourcename']), results['sourcedata']), 2491 ] 2492 2493 if not options.quiet: 2494 paths = " and ".join([x[0] for x in to_write]) 2495 sys.stderr.write("Writing to %s\n" % paths) 2496 2497 for path, data in to_write: 2498 dirname = os.path.dirname(path) 2499 if dirname and not os.path.exists(dirname): 2500 os.makedirs(dirname) 2501 2502 with open(path, 'w') as f: 2503 f.write(data) 2504 2505def main_plugin(): 2506 '''Main function when invoked as a protoc plugin.''' 2507 2508 import io, sys 2509 if sys.platform == "win32": 2510 import os, msvcrt 2511 # Set stdin and stdout to binary mode 2512 msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY) 2513 msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY) 2514 2515 data = io.open(sys.stdin.fileno(), "rb").read() 2516 2517 request = plugin_pb2.CodeGeneratorRequest.FromString(data) 2518 2519 try: 2520 # Versions of Python prior to 2.7.3 do not support unicode 2521 # input to shlex.split(). Try to convert to str if possible. 2522 params = str(request.parameter) 2523 except UnicodeEncodeError: 2524 params = request.parameter 2525 2526 if ',' not in params and ' -' in params: 2527 # Nanopb has traditionally supported space as separator in options 2528 args = shlex.split(params) 2529 else: 2530 # Protoc separates options passed to plugins by comma 2531 # This allows also giving --nanopb_opt option multiple times. 2532 lex = shlex.shlex(params) 2533 lex.whitespace_split = True 2534 lex.whitespace = ',' 2535 lex.commenters = '' 2536 args = list(lex) 2537 2538 optparser.usage = "protoc --nanopb_out=outdir [--nanopb_opt=option] ['--nanopb_opt=option with spaces'] file.proto" 2539 optparser.epilog = "Output will be written to file.pb.h and file.pb.c." 2540 2541 if '-h' in args or '--help' in args: 2542 # By default optparser prints help to stdout, which doesn't work for 2543 # protoc plugins. 2544 optparser.print_help(sys.stderr) 2545 sys.exit(1) 2546 2547 options, dummy = process_cmdline(args, is_plugin = True) 2548 2549 response = plugin_pb2.CodeGeneratorResponse() 2550 2551 # Google's protoc does not currently indicate the full path of proto files. 2552 # Instead always add the main file path to the search dirs, that works for 2553 # the common case. 2554 import os.path 2555 options.options_path.append(os.path.dirname(request.file_to_generate[0])) 2556 2557 # Process any include files first, in order to have them 2558 # available as dependencies 2559 other_files = {} 2560 for fdesc in request.proto_file: 2561 other_files[fdesc.name] = parse_file(fdesc.name, fdesc, options) 2562 2563 for filename in request.file_to_generate: 2564 for fdesc in request.proto_file: 2565 if fdesc.name == filename: 2566 results = process_file(filename, fdesc, options, other_files) 2567 2568 f = response.file.add() 2569 f.name = results['headername'] 2570 f.content = results['headerdata'] 2571 2572 f = response.file.add() 2573 f.name = results['sourcename'] 2574 f.content = results['sourcedata'] 2575 2576 if hasattr(plugin_pb2.CodeGeneratorResponse, "FEATURE_PROTO3_OPTIONAL"): 2577 response.supported_features = plugin_pb2.CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL 2578 2579 io.open(sys.stdout.fileno(), "wb").write(response.SerializeToString()) 2580 2581if __name__ == '__main__': 2582 # Check if we are running as a plugin under protoc 2583 if 'protoc-gen-' in sys.argv[0] or '--protoc-plugin' in sys.argv: 2584 main_plugin() 2585 else: 2586 main_cli() 2587