1#!/usr/bin/env python3 2# kate: replace-tabs on; indent-width 4; 3 4from __future__ import unicode_literals 5 6'''Generate header file for nanopb from a ProtoBuf FileDescriptorSet.''' 7nanopb_version = "nanopb-0.4.6-dev" 8 9import sys 10import re 11import codecs 12import copy 13import itertools 14import tempfile 15import shutil 16import os 17from functools import reduce 18 19try: 20 # Add some dummy imports to keep packaging tools happy. 21 import google, distutils.util # bbfreeze seems to need these 22 import pkg_resources # pyinstaller / protobuf 2.5 seem to need these 23 import proto.nanopb_pb2 as nanopb_pb2 # pyinstaller seems to need this 24 import pkg_resources.py2_warn 25except: 26 # Don't care, we will error out later if it is actually important. 27 pass 28 29try: 30 # Make sure grpc_tools gets included in binary package if it is available 31 import grpc_tools.protoc 32except: 33 pass 34 35try: 36 import google.protobuf.text_format as text_format 37 import google.protobuf.descriptor_pb2 as descriptor 38 import google.protobuf.compiler.plugin_pb2 as plugin_pb2 39 import google.protobuf.reflection as reflection 40 import google.protobuf.descriptor 41except: 42 sys.stderr.write(''' 43 ************************************************************* 44 *** Could not import the Google protobuf Python libraries *** 45 *** Try installing package 'python3-protobuf' or similar. *** 46 ************************************************************* 47 ''' + '\n') 48 raise 49 50try: 51 from .proto import nanopb_pb2 52 from .proto._utils import invoke_protoc 53except TypeError: 54 sys.stderr.write(''' 55 **************************************************************************** 56 *** Got TypeError when importing the protocol definitions for generator. *** 57 *** This usually means that the protoc in your path doesn't match the *** 58 *** Python protobuf library version. *** 59 *** *** 60 *** Please check the output of the following commands: *** 61 *** which protoc *** 62 *** protoc --version *** 63 *** python3 -c 'import google.protobuf; print(google.protobuf.__file__)' *** 64 *** If you are not able to find the python protobuf version using the *** 65 *** above command, use this command. *** 66 *** pip freeze | grep -i protobuf *** 67 **************************************************************************** 68 ''' + '\n') 69 raise 70except (ValueError, SystemError, ImportError): 71 # Probably invoked directly instead of via installed scripts. 72 import proto.nanopb_pb2 as nanopb_pb2 73 from proto._utils import invoke_protoc 74except: 75 sys.stderr.write(''' 76 ******************************************************************** 77 *** Failed to import the protocol definitions for generator. *** 78 *** You have to run 'make' in the nanopb/generator/proto folder. *** 79 ******************************************************************** 80 ''' + '\n') 81 raise 82 83try: 84 from tempfile import TemporaryDirectory 85except ImportError: 86 class TemporaryDirectory: 87 '''TemporaryDirectory fallback for Python 2''' 88 def __enter__(self): 89 self.dir = tempfile.mkdtemp() 90 return self.dir 91 92 def __exit__(self, *args): 93 shutil.rmtree(self.dir) 94 95# --------------------------------------------------------------------------- 96# Generation of single fields 97# --------------------------------------------------------------------------- 98 99import time 100import os.path 101 102# Values are tuple (c type, pb type, encoded size, data_size) 103FieldD = descriptor.FieldDescriptorProto 104datatypes = { 105 FieldD.TYPE_BOOL: ('bool', 'BOOL', 1, 4), 106 FieldD.TYPE_DOUBLE: ('double', 'DOUBLE', 8, 8), 107 FieldD.TYPE_FIXED32: ('uint32_t', 'FIXED32', 4, 4), 108 FieldD.TYPE_FIXED64: ('uint64_t', 'FIXED64', 8, 8), 109 FieldD.TYPE_FLOAT: ('float', 'FLOAT', 4, 4), 110 FieldD.TYPE_INT32: ('int32_t', 'INT32', 10, 4), 111 FieldD.TYPE_INT64: ('int64_t', 'INT64', 10, 8), 112 FieldD.TYPE_SFIXED32: ('int32_t', 'SFIXED32', 4, 4), 113 FieldD.TYPE_SFIXED64: ('int64_t', 'SFIXED64', 8, 8), 114 FieldD.TYPE_SINT32: ('int32_t', 'SINT32', 5, 4), 115 FieldD.TYPE_SINT64: ('int64_t', 'SINT64', 10, 8), 116 FieldD.TYPE_UINT32: ('uint32_t', 'UINT32', 5, 4), 117 FieldD.TYPE_UINT64: ('uint64_t', 'UINT64', 10, 8), 118 119 # Integer size override options 120 (FieldD.TYPE_INT32, nanopb_pb2.IS_8): ('int8_t', 'INT32', 10, 1), 121 (FieldD.TYPE_INT32, nanopb_pb2.IS_16): ('int16_t', 'INT32', 10, 2), 122 (FieldD.TYPE_INT32, nanopb_pb2.IS_32): ('int32_t', 'INT32', 10, 4), 123 (FieldD.TYPE_INT32, nanopb_pb2.IS_64): ('int64_t', 'INT32', 10, 8), 124 (FieldD.TYPE_SINT32, nanopb_pb2.IS_8): ('int8_t', 'SINT32', 2, 1), 125 (FieldD.TYPE_SINT32, nanopb_pb2.IS_16): ('int16_t', 'SINT32', 3, 2), 126 (FieldD.TYPE_SINT32, nanopb_pb2.IS_32): ('int32_t', 'SINT32', 5, 4), 127 (FieldD.TYPE_SINT32, nanopb_pb2.IS_64): ('int64_t', 'SINT32', 10, 8), 128 (FieldD.TYPE_UINT32, nanopb_pb2.IS_8): ('uint8_t', 'UINT32', 2, 1), 129 (FieldD.TYPE_UINT32, nanopb_pb2.IS_16): ('uint16_t','UINT32', 3, 2), 130 (FieldD.TYPE_UINT32, nanopb_pb2.IS_32): ('uint32_t','UINT32', 5, 4), 131 (FieldD.TYPE_UINT32, nanopb_pb2.IS_64): ('uint64_t','UINT32', 10, 8), 132 (FieldD.TYPE_INT64, nanopb_pb2.IS_8): ('int8_t', 'INT64', 10, 1), 133 (FieldD.TYPE_INT64, nanopb_pb2.IS_16): ('int16_t', 'INT64', 10, 2), 134 (FieldD.TYPE_INT64, nanopb_pb2.IS_32): ('int32_t', 'INT64', 10, 4), 135 (FieldD.TYPE_INT64, nanopb_pb2.IS_64): ('int64_t', 'INT64', 10, 8), 136 (FieldD.TYPE_SINT64, nanopb_pb2.IS_8): ('int8_t', 'SINT64', 2, 1), 137 (FieldD.TYPE_SINT64, nanopb_pb2.IS_16): ('int16_t', 'SINT64', 3, 2), 138 (FieldD.TYPE_SINT64, nanopb_pb2.IS_32): ('int32_t', 'SINT64', 5, 4), 139 (FieldD.TYPE_SINT64, nanopb_pb2.IS_64): ('int64_t', 'SINT64', 10, 8), 140 (FieldD.TYPE_UINT64, nanopb_pb2.IS_8): ('uint8_t', 'UINT64', 2, 1), 141 (FieldD.TYPE_UINT64, nanopb_pb2.IS_16): ('uint16_t','UINT64', 3, 2), 142 (FieldD.TYPE_UINT64, nanopb_pb2.IS_32): ('uint32_t','UINT64', 5, 4), 143 (FieldD.TYPE_UINT64, nanopb_pb2.IS_64): ('uint64_t','UINT64', 10, 8), 144} 145 146class Globals: 147 '''Ugly global variables, should find a good way to pass these.''' 148 verbose_options = False 149 separate_options = [] 150 matched_namemasks = set() 151 protoc_insertion_points = False 152 153# String types (for python 2 / python 3 compatibility) 154try: 155 strtypes = (unicode, str) 156 openmode_unicode = 'rU' 157except NameError: 158 strtypes = (str, ) 159 openmode_unicode = 'r' 160 161 162class Names: 163 '''Keeps a set of nested names and formats them to C identifier.''' 164 def __init__(self, parts = ()): 165 if isinstance(parts, Names): 166 parts = parts.parts 167 elif isinstance(parts, strtypes): 168 parts = (parts,) 169 self.parts = tuple(parts) 170 171 def __str__(self): 172 return '_'.join(self.parts) 173 174 def __add__(self, other): 175 if isinstance(other, strtypes): 176 return Names(self.parts + (other,)) 177 elif isinstance(other, Names): 178 return Names(self.parts + other.parts) 179 elif isinstance(other, tuple): 180 return Names(self.parts + other) 181 else: 182 raise ValueError("Name parts should be of type str") 183 184 def __eq__(self, other): 185 return isinstance(other, Names) and self.parts == other.parts 186 187 def __lt__(self, other): 188 if not isinstance(other, Names): 189 return NotImplemented 190 return str(self) < str(other) 191 192def names_from_type_name(type_name): 193 '''Parse Names() from FieldDescriptorProto type_name''' 194 if type_name[0] != '.': 195 raise NotImplementedError("Lookup of non-absolute type names is not supported") 196 return Names(type_name[1:].split('.')) 197 198def varint_max_size(max_value): 199 '''Returns the maximum number of bytes a varint can take when encoded.''' 200 if max_value < 0: 201 max_value = 2**64 - max_value 202 for i in range(1, 11): 203 if (max_value >> (i * 7)) == 0: 204 return i 205 raise ValueError("Value too large for varint: " + str(max_value)) 206 207assert varint_max_size(-1) == 10 208assert varint_max_size(0) == 1 209assert varint_max_size(127) == 1 210assert varint_max_size(128) == 2 211 212class EncodedSize: 213 '''Class used to represent the encoded size of a field or a message. 214 Consists of a combination of symbolic sizes and integer sizes.''' 215 def __init__(self, value = 0, symbols = [], declarations = [], required_defines = []): 216 if isinstance(value, EncodedSize): 217 self.value = value.value 218 self.symbols = value.symbols 219 self.declarations = value.declarations 220 self.required_defines = value.required_defines 221 elif isinstance(value, strtypes + (Names,)): 222 self.symbols = [str(value)] 223 self.value = 0 224 self.declarations = [] 225 self.required_defines = [str(value)] 226 else: 227 self.value = value 228 self.symbols = symbols 229 self.declarations = declarations 230 self.required_defines = required_defines 231 232 def __add__(self, other): 233 if isinstance(other, int): 234 return EncodedSize(self.value + other, self.symbols, self.declarations, self.required_defines) 235 elif isinstance(other, strtypes + (Names,)): 236 return EncodedSize(self.value, self.symbols + [str(other)], self.declarations, self.required_defines + [str(other)]) 237 elif isinstance(other, EncodedSize): 238 return EncodedSize(self.value + other.value, self.symbols + other.symbols, 239 self.declarations + other.declarations, self.required_defines + other.required_defines) 240 else: 241 raise ValueError("Cannot add size: " + repr(other)) 242 243 def __mul__(self, other): 244 if isinstance(other, int): 245 return EncodedSize(self.value * other, [str(other) + '*' + s for s in self.symbols], 246 self.declarations, self.required_defines) 247 else: 248 raise ValueError("Cannot multiply size: " + repr(other)) 249 250 def __str__(self): 251 if not self.symbols: 252 return str(self.value) 253 else: 254 return '(' + str(self.value) + ' + ' + ' + '.join(self.symbols) + ')' 255 256 def get_declarations(self): 257 '''Get any declarations that must appear alongside this encoded size definition, 258 such as helper union {} types.''' 259 return '\n'.join(self.declarations) 260 261 def get_cpp_guard(self, local_defines): 262 '''Get an #if preprocessor statement listing all defines that are required for this definition.''' 263 needed = [x for x in self.required_defines if x not in local_defines] 264 if needed: 265 return '#if ' + ' && '.join(['defined(%s)' % x for x in needed]) + "\n" 266 else: 267 return '' 268 269 def upperlimit(self): 270 if not self.symbols: 271 return self.value 272 else: 273 return 2**32 - 1 274 275 276''' 277Constants regarding path of proto elements in file descriptor. 278They are used to connect proto elements with source code information (comments) 279These values come from: 280 https://github.com/google/protobuf/blob/master/src/google/protobuf/descriptor.proto 281''' 282MESSAGE_PATH = 4 283ENUM_PATH = 5 284FIELD_PATH = 2 285 286 287class ProtoElement(object): 288 def __init__(self, path, index, comments): 289 ''' 290 path is a predefined value for each element type in proto file. 291 For example, message == 4, enum == 5, service == 6 292 index is the N-th occurance of the `path` in the proto file. 293 For example, 4-th message in the proto file or 2-nd enum etc ... 294 comments is a dictionary mapping between element path & SourceCodeInfo.Location 295 (contains information about source comments). 296 ''' 297 self.path = path 298 self.index = index 299 self.comments = comments 300 301 def element_path(self): 302 '''Get path to proto element.''' 303 return [self.path, self.index] 304 305 def member_path(self, member_index): 306 '''Get path to member of proto element. 307 Example paths: 308 [4, m] - message comments, m: msgIdx in proto from 0 309 [4, m, 2, f] - field comments in message, f: fieldIdx in message from 0 310 [6, s] - service comments, s: svcIdx in proto from 0 311 [6, s, 2, r] - rpc comments in service, r: rpc method def in service from 0 312 ''' 313 return self.element_path() + [FIELD_PATH, member_index] 314 315 def get_comments(self, path, leading_indent=True): 316 '''Get leading & trailing comments for enum member based on path. 317 318 path is the proto path of an element or member (ex. [5 0] or [4 1 2 0]) 319 leading_indent is a flag that indicates if leading comments should be indented 320 ''' 321 322 # Obtain SourceCodeInfo.Location object containing comment 323 # information (based on the member path) 324 comment = self.comments.get(str(path)) 325 326 leading_comment = "" 327 trailing_comment = "" 328 329 if not comment: 330 return leading_comment, trailing_comment 331 332 if comment.leading_comments: 333 leading_comment = " " if leading_indent else "" 334 leading_comment += "/* %s */" % comment.leading_comments.strip() 335 336 if comment.trailing_comments: 337 trailing_comment = "/* %s */" % comment.trailing_comments.strip() 338 339 return leading_comment, trailing_comment 340 341 342class Enum(ProtoElement): 343 def __init__(self, names, desc, enum_options, index, comments): 344 ''' 345 desc is EnumDescriptorProto 346 index is the index of this enum element inside the file 347 comments is a dictionary mapping between element path & SourceCodeInfo.Location 348 (contains information about source comments) 349 ''' 350 super(Enum, self).__init__(ENUM_PATH, index, comments) 351 352 self.options = enum_options 353 self.names = names 354 355 # by definition, `names` include this enum's name 356 base_name = Names(names.parts[:-1]) 357 358 if enum_options.long_names: 359 self.values = [(names + x.name, x.number) for x in desc.value] 360 else: 361 self.values = [(base_name + x.name, x.number) for x in desc.value] 362 363 self.value_longnames = [self.names + x.name for x in desc.value] 364 self.packed = enum_options.packed_enum 365 366 def has_negative(self): 367 for n, v in self.values: 368 if v < 0: 369 return True 370 return False 371 372 def encoded_size(self): 373 return max([varint_max_size(v) for n,v in self.values]) 374 375 def __str__(self): 376 enum_path = self.element_path() 377 leading_comment, trailing_comment = self.get_comments(enum_path, leading_indent=False) 378 379 result = '' 380 if leading_comment: 381 result = '%s\n' % leading_comment 382 383 result += 'typedef enum _%s { %s\n' % (self.names, trailing_comment) 384 385 enum_length = len(self.values) 386 enum_values = [] 387 for index, (name, value) in enumerate(self.values): 388 member_path = self.member_path(index) 389 leading_comment, trailing_comment = self.get_comments(member_path) 390 391 if leading_comment: 392 enum_values.append(leading_comment) 393 394 comma = "," 395 if index == enum_length - 1: 396 # last enum member should not end with a comma 397 comma = "" 398 399 enum_values.append(" %s = %d%s %s" % (name, value, comma, trailing_comment)) 400 401 result += '\n'.join(enum_values) 402 result += '\n}' 403 404 if self.packed: 405 result += ' pb_packed' 406 407 result += ' %s;' % self.names 408 return result 409 410 def auxiliary_defines(self): 411 # sort the enum by value 412 sorted_values = sorted(self.values, key = lambda x: (x[1], x[0])) 413 result = '#define _%s_MIN %s\n' % (self.names, sorted_values[0][0]) 414 result += '#define _%s_MAX %s\n' % (self.names, sorted_values[-1][0]) 415 result += '#define _%s_ARRAYSIZE ((%s)(%s+1))\n' % (self.names, self.names, sorted_values[-1][0]) 416 417 if not self.options.long_names: 418 # Define the long names always so that enum value references 419 # from other files work properly. 420 for i, x in enumerate(self.values): 421 result += '#define %s %s\n' % (self.value_longnames[i], x[0]) 422 423 if self.options.enum_to_string: 424 result += 'const char *%s_name(%s v);\n' % (self.names, self.names) 425 426 return result 427 428 def enum_to_string_definition(self): 429 if not self.options.enum_to_string: 430 return "" 431 432 result = 'const char *%s_name(%s v) {\n' % (self.names, self.names) 433 result += ' switch (v) {\n' 434 435 for ((enumname, _), strname) in zip(self.values, self.value_longnames): 436 # Strip off the leading type name from the string value. 437 strval = str(strname)[len(str(self.names)) + 1:] 438 result += ' case %s: return "%s";\n' % (enumname, strval) 439 440 result += ' }\n' 441 result += ' return "unknown";\n' 442 result += '}\n' 443 444 return result 445 446class FieldMaxSize: 447 def __init__(self, worst = 0, checks = [], field_name = 'undefined'): 448 if isinstance(worst, list): 449 self.worst = max(i for i in worst if i is not None) 450 else: 451 self.worst = worst 452 453 self.worst_field = field_name 454 self.checks = list(checks) 455 456 def extend(self, extend, field_name = None): 457 self.worst = max(self.worst, extend.worst) 458 459 if self.worst == extend.worst: 460 self.worst_field = extend.worst_field 461 462 self.checks.extend(extend.checks) 463 464class Field: 465 macro_x_param = 'X' 466 macro_a_param = 'a' 467 468 def __init__(self, struct_name, desc, field_options): 469 '''desc is FieldDescriptorProto''' 470 self.tag = desc.number 471 self.struct_name = struct_name 472 self.union_name = None 473 self.name = desc.name 474 self.default = None 475 self.max_size = None 476 self.max_count = None 477 self.array_decl = "" 478 self.enc_size = None 479 self.data_item_size = None 480 self.ctype = None 481 self.fixed_count = False 482 self.callback_datatype = field_options.callback_datatype 483 self.math_include_required = False 484 self.sort_by_tag = field_options.sort_by_tag 485 486 if field_options.type == nanopb_pb2.FT_INLINE: 487 # Before nanopb-0.3.8, fixed length bytes arrays were specified 488 # by setting type to FT_INLINE. But to handle pointer typed fields, 489 # it makes sense to have it as a separate option. 490 field_options.type = nanopb_pb2.FT_STATIC 491 field_options.fixed_length = True 492 493 # Parse field options 494 if field_options.HasField("max_size"): 495 self.max_size = field_options.max_size 496 497 self.default_has = field_options.default_has 498 499 if desc.type == FieldD.TYPE_STRING and field_options.HasField("max_length"): 500 # max_length overrides max_size for strings 501 self.max_size = field_options.max_length + 1 502 503 if field_options.HasField("max_count"): 504 self.max_count = field_options.max_count 505 506 if desc.HasField('default_value'): 507 self.default = desc.default_value 508 509 # Check field rules, i.e. required/optional/repeated. 510 can_be_static = True 511 if desc.label == FieldD.LABEL_REPEATED: 512 self.rules = 'REPEATED' 513 if self.max_count is None: 514 can_be_static = False 515 else: 516 self.array_decl = '[%d]' % self.max_count 517 if field_options.fixed_count: 518 self.rules = 'FIXARRAY' 519 520 elif field_options.proto3: 521 if desc.type == FieldD.TYPE_MESSAGE and not field_options.proto3_singular_msgs: 522 # In most other protobuf libraries proto3 submessages have 523 # "null" status. For nanopb, that is implemented as has_ field. 524 self.rules = 'OPTIONAL' 525 elif hasattr(desc, "proto3_optional") and desc.proto3_optional: 526 # Protobuf 3.12 introduced optional fields for proto3 syntax 527 self.rules = 'OPTIONAL' 528 else: 529 # Proto3 singular fields (without has_field) 530 self.rules = 'SINGULAR' 531 elif desc.label == FieldD.LABEL_REQUIRED: 532 self.rules = 'REQUIRED' 533 elif desc.label == FieldD.LABEL_OPTIONAL: 534 self.rules = 'OPTIONAL' 535 else: 536 raise NotImplementedError(desc.label) 537 538 # Check if the field can be implemented with static allocation 539 # i.e. whether the data size is known. 540 if desc.type == FieldD.TYPE_STRING and self.max_size is None: 541 can_be_static = False 542 543 if desc.type == FieldD.TYPE_BYTES and self.max_size is None: 544 can_be_static = False 545 546 # Decide how the field data will be allocated 547 if field_options.type == nanopb_pb2.FT_DEFAULT: 548 if can_be_static: 549 field_options.type = nanopb_pb2.FT_STATIC 550 else: 551 field_options.type = nanopb_pb2.FT_CALLBACK 552 553 if field_options.type == nanopb_pb2.FT_STATIC and not can_be_static: 554 raise Exception("Field '%s' is defined as static, but max_size or " 555 "max_count is not given." % self.name) 556 557 if field_options.fixed_count and self.max_count is None: 558 raise Exception("Field '%s' is defined as fixed count, " 559 "but max_count is not given." % self.name) 560 561 if field_options.type == nanopb_pb2.FT_STATIC: 562 self.allocation = 'STATIC' 563 elif field_options.type == nanopb_pb2.FT_POINTER: 564 self.allocation = 'POINTER' 565 elif field_options.type == nanopb_pb2.FT_CALLBACK: 566 self.allocation = 'CALLBACK' 567 else: 568 raise NotImplementedError(field_options.type) 569 570 if field_options.HasField("type_override"): 571 desc.type = field_options.type_override 572 573 # Decide the C data type to use in the struct. 574 if desc.type in datatypes: 575 self.ctype, self.pbtype, self.enc_size, self.data_item_size = datatypes[desc.type] 576 577 # Override the field size if user wants to use smaller integers 578 if (desc.type, field_options.int_size) in datatypes: 579 self.ctype, self.pbtype, self.enc_size, self.data_item_size = datatypes[(desc.type, field_options.int_size)] 580 elif desc.type == FieldD.TYPE_ENUM: 581 self.pbtype = 'ENUM' 582 self.data_item_size = 4 583 self.ctype = names_from_type_name(desc.type_name) 584 if self.default is not None: 585 self.default = self.ctype + self.default 586 self.enc_size = None # Needs to be filled in when enum values are known 587 elif desc.type == FieldD.TYPE_STRING: 588 self.pbtype = 'STRING' 589 self.ctype = 'char' 590 if self.allocation == 'STATIC': 591 self.ctype = 'char' 592 self.array_decl += '[%d]' % self.max_size 593 # -1 because of null terminator. Both pb_encode and pb_decode 594 # check the presence of it. 595 self.enc_size = varint_max_size(self.max_size) + self.max_size - 1 596 elif desc.type == FieldD.TYPE_BYTES: 597 if field_options.fixed_length: 598 self.pbtype = 'FIXED_LENGTH_BYTES' 599 600 if self.max_size is None: 601 raise Exception("Field '%s' is defined as fixed length, " 602 "but max_size is not given." % self.name) 603 604 self.enc_size = varint_max_size(self.max_size) + self.max_size 605 self.ctype = 'pb_byte_t' 606 self.array_decl += '[%d]' % self.max_size 607 else: 608 self.pbtype = 'BYTES' 609 self.ctype = 'pb_bytes_array_t' 610 if self.allocation == 'STATIC': 611 self.ctype = self.struct_name + self.name + 't' 612 self.enc_size = varint_max_size(self.max_size) + self.max_size 613 elif desc.type == FieldD.TYPE_MESSAGE: 614 self.pbtype = 'MESSAGE' 615 self.ctype = self.submsgname = names_from_type_name(desc.type_name) 616 self.enc_size = None # Needs to be filled in after the message type is available 617 if field_options.submsg_callback and self.allocation == 'STATIC': 618 self.pbtype = 'MSG_W_CB' 619 else: 620 raise NotImplementedError(desc.type) 621 622 if self.default and self.pbtype in ['FLOAT', 'DOUBLE']: 623 if 'inf' in self.default or 'nan' in self.default: 624 self.math_include_required = True 625 626 def __lt__(self, other): 627 return self.tag < other.tag 628 629 def __str__(self): 630 result = '' 631 if self.allocation == 'POINTER': 632 if self.rules == 'REPEATED': 633 if self.pbtype == 'MSG_W_CB': 634 result += ' pb_callback_t cb_' + self.name + ';\n' 635 result += ' pb_size_t ' + self.name + '_count;\n' 636 637 if self.pbtype in ['MESSAGE', 'MSG_W_CB']: 638 # Use struct definition, so recursive submessages are possible 639 result += ' struct _%s *%s;' % (self.ctype, self.name) 640 elif self.pbtype == 'FIXED_LENGTH_BYTES' or self.rules == 'FIXARRAY': 641 # Pointer to fixed size array 642 result += ' %s (*%s)%s;' % (self.ctype, self.name, self.array_decl) 643 elif self.rules in ['REPEATED', 'FIXARRAY'] and self.pbtype in ['STRING', 'BYTES']: 644 # String/bytes arrays need to be defined as pointers to pointers 645 result += ' %s **%s;' % (self.ctype, self.name) 646 else: 647 result += ' %s *%s;' % (self.ctype, self.name) 648 elif self.allocation == 'CALLBACK': 649 result += ' %s %s;' % (self.callback_datatype, self.name) 650 else: 651 if self.pbtype == 'MSG_W_CB' and self.rules in ['OPTIONAL', 'REPEATED']: 652 result += ' pb_callback_t cb_' + self.name + ';\n' 653 654 if self.rules == 'OPTIONAL': 655 result += ' bool has_' + self.name + ';\n' 656 elif self.rules == 'REPEATED': 657 result += ' pb_size_t ' + self.name + '_count;\n' 658 result += ' %s %s%s;' % (self.ctype, self.name, self.array_decl) 659 return result 660 661 def types(self): 662 '''Return definitions for any special types this field might need.''' 663 if self.pbtype == 'BYTES' and self.allocation == 'STATIC': 664 result = 'typedef PB_BYTES_ARRAY_T(%d) %s;\n' % (self.max_size, self.ctype) 665 else: 666 result = '' 667 return result 668 669 def get_dependencies(self): 670 '''Get list of type names used by this field.''' 671 if self.allocation == 'STATIC': 672 return [str(self.ctype)] 673 else: 674 return [] 675 676 def get_initializer(self, null_init, inner_init_only = False): 677 '''Return literal expression for this field's default value. 678 null_init: If True, initialize to a 0 value instead of default from .proto 679 inner_init_only: If True, exclude initialization for any count/has fields 680 ''' 681 682 inner_init = None 683 if self.pbtype in ['MESSAGE', 'MSG_W_CB']: 684 if null_init: 685 inner_init = '%s_init_zero' % self.ctype 686 else: 687 inner_init = '%s_init_default' % self.ctype 688 elif self.default is None or null_init: 689 if self.pbtype == 'STRING': 690 inner_init = '""' 691 elif self.pbtype == 'BYTES': 692 inner_init = '{0, {0}}' 693 elif self.pbtype == 'FIXED_LENGTH_BYTES': 694 inner_init = '{0}' 695 elif self.pbtype in ('ENUM', 'UENUM'): 696 inner_init = '_%s_MIN' % self.ctype 697 else: 698 inner_init = '0' 699 else: 700 if self.pbtype == 'STRING': 701 data = codecs.escape_encode(self.default.encode('utf-8'))[0] 702 inner_init = '"' + data.decode('ascii') + '"' 703 elif self.pbtype == 'BYTES': 704 data = codecs.escape_decode(self.default)[0] 705 data = ["0x%02x" % c for c in bytearray(data)] 706 if len(data) == 0: 707 inner_init = '{0, {0}}' 708 else: 709 inner_init = '{%d, {%s}}' % (len(data), ','.join(data)) 710 elif self.pbtype == 'FIXED_LENGTH_BYTES': 711 data = codecs.escape_decode(self.default)[0] 712 data = ["0x%02x" % c for c in bytearray(data)] 713 if len(data) == 0: 714 inner_init = '{0}' 715 else: 716 inner_init = '{%s}' % ','.join(data) 717 elif self.pbtype in ['FIXED32', 'UINT32']: 718 inner_init = str(self.default) + 'u' 719 elif self.pbtype in ['FIXED64', 'UINT64']: 720 inner_init = str(self.default) + 'ull' 721 elif self.pbtype in ['SFIXED64', 'INT64']: 722 inner_init = str(self.default) + 'll' 723 elif self.pbtype in ['FLOAT', 'DOUBLE']: 724 inner_init = str(self.default) 725 if 'inf' in inner_init: 726 inner_init = inner_init.replace('inf', 'INFINITY') 727 elif 'nan' in inner_init: 728 inner_init = inner_init.replace('nan', 'NAN') 729 elif (not '.' in inner_init) and self.pbtype == 'FLOAT': 730 inner_init += '.0f' 731 elif self.pbtype == 'FLOAT': 732 inner_init += 'f' 733 else: 734 inner_init = str(self.default) 735 736 if inner_init_only: 737 return inner_init 738 739 outer_init = None 740 if self.allocation == 'STATIC': 741 if self.rules == 'REPEATED': 742 outer_init = '0, {' + ', '.join([inner_init] * self.max_count) + '}' 743 elif self.rules == 'FIXARRAY': 744 outer_init = '{' + ', '.join([inner_init] * self.max_count) + '}' 745 elif self.rules == 'OPTIONAL': 746 if null_init or not self.default_has: 747 outer_init = 'false, ' + inner_init 748 else: 749 outer_init = 'true, ' + inner_init 750 else: 751 outer_init = inner_init 752 elif self.allocation == 'POINTER': 753 if self.rules == 'REPEATED': 754 outer_init = '0, NULL' 755 else: 756 outer_init = 'NULL' 757 elif self.allocation == 'CALLBACK': 758 if self.pbtype == 'EXTENSION': 759 outer_init = 'NULL' 760 else: 761 outer_init = '{{NULL}, NULL}' 762 763 if self.pbtype == 'MSG_W_CB' and self.rules in ['REPEATED', 'OPTIONAL']: 764 outer_init = '{{NULL}, NULL}, ' + outer_init 765 766 return outer_init 767 768 def tags(self): 769 '''Return the #define for the tag number of this field.''' 770 identifier = '%s_%s_tag' % (self.struct_name, self.name) 771 return '#define %-40s %d\n' % (identifier, self.tag) 772 773 def fieldlist(self): 774 '''Return the FIELDLIST macro entry for this field. 775 Format is: X(a, ATYPE, HTYPE, LTYPE, field_name, tag) 776 ''' 777 name = self.name 778 779 if self.rules == "ONEOF": 780 # For oneofs, make a tuple of the union name, union member name, 781 # and the name inside the parent struct. 782 if not self.anonymous: 783 name = '(%s,%s,%s)' % (self.union_name, self.name, self.union_name + '.' + self.name) 784 else: 785 name = '(%s,%s,%s)' % (self.union_name, self.name, self.name) 786 787 return '%s(%s, %-9s %-9s %-9s %-16s %3d)' % (self.macro_x_param, 788 self.macro_a_param, 789 self.allocation + ',', 790 self.rules + ',', 791 self.pbtype + ',', 792 name + ',', 793 self.tag) 794 795 def data_size(self, dependencies): 796 '''Return estimated size of this field in the C struct. 797 This is used to try to automatically pick right descriptor size. 798 If the estimate is wrong, it will result in compile time error and 799 user having to specify descriptor_width option. 800 ''' 801 if self.allocation == 'POINTER' or self.pbtype == 'EXTENSION': 802 size = 8 803 alignment = 8 804 elif self.allocation == 'CALLBACK': 805 size = 16 806 alignment = 8 807 elif self.pbtype in ['MESSAGE', 'MSG_W_CB']: 808 alignment = 8 809 if str(self.submsgname) in dependencies: 810 other_dependencies = dict(x for x in dependencies.items() if x[0] != str(self.struct_name)) 811 size = dependencies[str(self.submsgname)].data_size(other_dependencies) 812 else: 813 size = 256 # Message is in other file, this is reasonable guess for most cases 814 815 if self.pbtype == 'MSG_W_CB': 816 size += 16 817 elif self.pbtype in ['STRING', 'FIXED_LENGTH_BYTES']: 818 size = self.max_size 819 alignment = 4 820 elif self.pbtype == 'BYTES': 821 size = self.max_size + 4 822 alignment = 4 823 elif self.data_item_size is not None: 824 size = self.data_item_size 825 alignment = 4 826 if self.data_item_size >= 8: 827 alignment = 8 828 else: 829 raise Exception("Unhandled field type: %s" % self.pbtype) 830 831 if self.rules in ['REPEATED', 'FIXARRAY'] and self.allocation == 'STATIC': 832 size *= self.max_count 833 834 if self.rules not in ('REQUIRED', 'SINGULAR'): 835 size += 4 836 837 if size % alignment != 0: 838 # Estimate how much alignment requirements will increase the size. 839 size += alignment - (size % alignment) 840 841 return size 842 843 def encoded_size(self, dependencies): 844 '''Return the maximum size that this field can take when encoded, 845 including the field tag. If the size cannot be determined, returns 846 None.''' 847 848 if self.allocation != 'STATIC': 849 return None 850 851 if self.pbtype in ['MESSAGE', 'MSG_W_CB']: 852 encsize = None 853 if str(self.submsgname) in dependencies: 854 submsg = dependencies[str(self.submsgname)] 855 other_dependencies = dict(x for x in dependencies.items() if x[0] != str(self.struct_name)) 856 encsize = submsg.encoded_size(other_dependencies) 857 858 my_msg = dependencies.get(str(self.struct_name)) 859 external = (not my_msg or submsg.protofile != my_msg.protofile) 860 861 if encsize and encsize.symbols and external: 862 # Couldn't fully resolve the size of a dependency from 863 # another file. Instead of including the symbols directly, 864 # just use the #define SubMessage_size from the header. 865 encsize = None 866 867 if encsize is not None: 868 # Include submessage length prefix 869 encsize += varint_max_size(encsize.upperlimit()) 870 elif not external: 871 # The dependency is from the same file and size cannot be 872 # determined for it, thus we know it will not be possible 873 # in runtime either. 874 return None 875 876 if encsize is None: 877 # Submessage or its size cannot be found. 878 # This can occur if submessage is defined in different 879 # file, and it or its .options could not be found. 880 # Instead of direct numeric value, reference the size that 881 # has been #defined in the other file. 882 encsize = EncodedSize(self.submsgname + 'size') 883 884 # We will have to make a conservative assumption on the length 885 # prefix size, though. 886 encsize += 5 887 888 elif self.pbtype in ['ENUM', 'UENUM']: 889 if str(self.ctype) in dependencies: 890 enumtype = dependencies[str(self.ctype)] 891 encsize = enumtype.encoded_size() 892 else: 893 # Conservative assumption 894 encsize = 10 895 896 elif self.enc_size is None: 897 raise RuntimeError("Could not determine encoded size for %s.%s" 898 % (self.struct_name, self.name)) 899 else: 900 encsize = EncodedSize(self.enc_size) 901 902 encsize += varint_max_size(self.tag << 3) # Tag + wire type 903 904 if self.rules in ['REPEATED', 'FIXARRAY']: 905 # Decoders must be always able to handle unpacked arrays. 906 # Therefore we have to reserve space for it, even though 907 # we emit packed arrays ourselves. For length of 1, packed 908 # arrays are larger however so we need to add allowance 909 # for the length byte. 910 encsize *= self.max_count 911 912 if self.max_count == 1: 913 encsize += 1 914 915 return encsize 916 917 def has_callbacks(self): 918 return self.allocation == 'CALLBACK' 919 920 def requires_custom_field_callback(self): 921 return self.allocation == 'CALLBACK' and self.callback_datatype != 'pb_callback_t' 922 923class ExtensionRange(Field): 924 def __init__(self, struct_name, range_start, field_options): 925 '''Implements a special pb_extension_t* field in an extensible message 926 structure. The range_start signifies the index at which the extensions 927 start. Not necessarily all tags above this are extensions, it is merely 928 a speed optimization. 929 ''' 930 self.tag = range_start 931 self.struct_name = struct_name 932 self.name = 'extensions' 933 self.pbtype = 'EXTENSION' 934 self.rules = 'OPTIONAL' 935 self.allocation = 'CALLBACK' 936 self.ctype = 'pb_extension_t' 937 self.array_decl = '' 938 self.default = None 939 self.max_size = 0 940 self.max_count = 0 941 self.data_item_size = 0 942 self.fixed_count = False 943 self.callback_datatype = 'pb_extension_t*' 944 945 def requires_custom_field_callback(self): 946 return False 947 948 def __str__(self): 949 return ' pb_extension_t *extensions;' 950 951 def types(self): 952 return '' 953 954 def tags(self): 955 return '' 956 957 def encoded_size(self, dependencies): 958 # We exclude extensions from the count, because they cannot be known 959 # until runtime. Other option would be to return None here, but this 960 # way the value remains useful if extensions are not used. 961 return EncodedSize(0) 962 963class ExtensionField(Field): 964 def __init__(self, fullname, desc, field_options): 965 self.fullname = fullname 966 self.extendee_name = names_from_type_name(desc.extendee) 967 Field.__init__(self, self.fullname + "extmsg", desc, field_options) 968 969 if self.rules != 'OPTIONAL': 970 self.skip = True 971 else: 972 self.skip = False 973 self.rules = 'REQUIRED' # We don't really want the has_field for extensions 974 # currently no support for comments for extension fields => provide 0, {} 975 self.msg = Message(self.fullname + "extmsg", None, field_options, 0, {}) 976 self.msg.fields.append(self) 977 978 def tags(self): 979 '''Return the #define for the tag number of this field.''' 980 identifier = '%s_tag' % self.fullname 981 return '#define %-40s %d\n' % (identifier, self.tag) 982 983 def extension_decl(self): 984 '''Declaration of the extension type in the .pb.h file''' 985 if self.skip: 986 msg = '/* Extension field %s was skipped because only "optional"\n' % self.fullname 987 msg +=' type of extension fields is currently supported. */\n' 988 return msg 989 990 return ('extern const pb_extension_type_t %s; /* field type: %s */\n' % 991 (self.fullname, str(self).strip())) 992 993 def extension_def(self, dependencies): 994 '''Definition of the extension type in the .pb.c file''' 995 996 if self.skip: 997 return '' 998 999 result = "/* Definition for extension field %s */\n" % self.fullname 1000 result += str(self.msg) 1001 result += self.msg.fields_declaration(dependencies) 1002 result += 'pb_byte_t %s_default[] = {0x00};\n' % self.msg.name 1003 result += self.msg.fields_definition(dependencies) 1004 result += 'const pb_extension_type_t %s = {\n' % self.fullname 1005 result += ' NULL,\n' 1006 result += ' NULL,\n' 1007 result += ' &%s_msg\n' % self.msg.name 1008 result += '};\n' 1009 return result 1010 1011 1012# --------------------------------------------------------------------------- 1013# Generation of oneofs (unions) 1014# --------------------------------------------------------------------------- 1015 1016class OneOf(Field): 1017 def __init__(self, struct_name, oneof_desc, oneof_options): 1018 self.struct_name = struct_name 1019 self.name = oneof_desc.name 1020 self.ctype = 'union' 1021 self.pbtype = 'oneof' 1022 self.fields = [] 1023 self.allocation = 'ONEOF' 1024 self.default = None 1025 self.rules = 'ONEOF' 1026 self.anonymous = oneof_options.anonymous_oneof 1027 self.sort_by_tag = oneof_options.sort_by_tag 1028 self.has_msg_cb = False 1029 1030 def add_field(self, field): 1031 field.union_name = self.name 1032 field.rules = 'ONEOF' 1033 field.anonymous = self.anonymous 1034 self.fields.append(field) 1035 1036 if self.sort_by_tag: 1037 self.fields.sort() 1038 1039 if field.pbtype == 'MSG_W_CB': 1040 self.has_msg_cb = True 1041 1042 # Sort by the lowest tag number inside union 1043 self.tag = min([f.tag for f in self.fields]) 1044 1045 def __str__(self): 1046 result = '' 1047 if self.fields: 1048 if self.has_msg_cb: 1049 result += ' pb_callback_t cb_' + self.name + ';\n' 1050 1051 result += ' pb_size_t which_' + self.name + ";\n" 1052 result += ' union {\n' 1053 for f in self.fields: 1054 result += ' ' + str(f).replace('\n', '\n ') + '\n' 1055 if self.anonymous: 1056 result += ' };' 1057 else: 1058 result += ' } ' + self.name + ';' 1059 return result 1060 1061 def types(self): 1062 return ''.join([f.types() for f in self.fields]) 1063 1064 def get_dependencies(self): 1065 deps = [] 1066 for f in self.fields: 1067 deps += f.get_dependencies() 1068 return deps 1069 1070 def get_initializer(self, null_init): 1071 if self.has_msg_cb: 1072 return '{{NULL}, NULL}, 0, {' + self.fields[0].get_initializer(null_init) + '}' 1073 else: 1074 return '0, {' + self.fields[0].get_initializer(null_init) + '}' 1075 1076 def tags(self): 1077 return ''.join([f.tags() for f in self.fields]) 1078 1079 def data_size(self, dependencies): 1080 return max(f.data_size(dependencies) for f in self.fields) 1081 1082 def encoded_size(self, dependencies): 1083 '''Returns the size of the largest oneof field.''' 1084 largest = 0 1085 dynamic_sizes = {} 1086 for f in self.fields: 1087 size = EncodedSize(f.encoded_size(dependencies)) 1088 if size is None or size.value is None: 1089 return None 1090 elif size.symbols: 1091 dynamic_sizes[f.tag] = size 1092 elif size.value > largest: 1093 largest = size.value 1094 1095 if not dynamic_sizes: 1096 # Simple case, all sizes were known at generator time 1097 return EncodedSize(largest) 1098 1099 if largest > 0: 1100 # Some sizes were known, some were not 1101 dynamic_sizes[0] = EncodedSize(largest) 1102 1103 # Couldn't find size for submessage at generation time, 1104 # have to rely on macro resolution at compile time. 1105 if len(dynamic_sizes) == 1: 1106 # Only one symbol was needed 1107 return list(dynamic_sizes.values())[0] 1108 else: 1109 # Use sizeof(union{}) construct to find the maximum size of 1110 # submessages. 1111 union_name = "%s_%s_size_union" % (self.struct_name, self.name) 1112 union_def = 'union %s {%s};\n' % (union_name, ' '.join('char f%d[%s];' % (k, s) for k,s in dynamic_sizes.items())) 1113 required_defs = list(itertools.chain.from_iterable(s.required_defines for k,s in dynamic_sizes.items())) 1114 return EncodedSize(0, ['sizeof(union %s)' % union_name], [union_def], required_defs) 1115 1116 def has_callbacks(self): 1117 return bool([f for f in self.fields if f.has_callbacks()]) 1118 1119 def requires_custom_field_callback(self): 1120 return bool([f for f in self.fields if f.requires_custom_field_callback()]) 1121 1122# --------------------------------------------------------------------------- 1123# Generation of messages (structures) 1124# --------------------------------------------------------------------------- 1125 1126 1127class Message(ProtoElement): 1128 def __init__(self, names, desc, message_options, index, comments): 1129 super(Message, self).__init__(MESSAGE_PATH, index, comments) 1130 self.name = names 1131 self.fields = [] 1132 self.oneofs = {} 1133 self.desc = desc 1134 self.math_include_required = False 1135 self.packed = message_options.packed_struct 1136 self.descriptorsize = message_options.descriptorsize 1137 1138 if message_options.msgid: 1139 self.msgid = message_options.msgid 1140 1141 if desc is not None: 1142 self.load_fields(desc, message_options) 1143 1144 self.callback_function = message_options.callback_function 1145 if not message_options.HasField('callback_function'): 1146 # Automatically assign a per-message callback if any field has 1147 # a special callback_datatype. 1148 for field in self.fields: 1149 if field.requires_custom_field_callback(): 1150 self.callback_function = "%s_callback" % self.name 1151 break 1152 1153 def load_fields(self, desc, message_options): 1154 '''Load field list from DescriptorProto''' 1155 1156 no_unions = [] 1157 1158 if hasattr(desc, 'oneof_decl'): 1159 for i, f in enumerate(desc.oneof_decl): 1160 oneof_options = get_nanopb_suboptions(desc, message_options, self.name + f.name) 1161 if oneof_options.no_unions: 1162 no_unions.append(i) # No union, but add fields normally 1163 elif oneof_options.type == nanopb_pb2.FT_IGNORE: 1164 pass # No union and skip fields also 1165 else: 1166 oneof = OneOf(self.name, f, oneof_options) 1167 self.oneofs[i] = oneof 1168 else: 1169 sys.stderr.write('Note: This Python protobuf library has no OneOf support\n') 1170 1171 for f in desc.field: 1172 field_options = get_nanopb_suboptions(f, message_options, self.name + f.name) 1173 if field_options.type == nanopb_pb2.FT_IGNORE: 1174 continue 1175 1176 if field_options.descriptorsize > self.descriptorsize: 1177 self.descriptorsize = field_options.descriptorsize 1178 1179 field = Field(self.name, f, field_options) 1180 if hasattr(f, 'oneof_index') and f.HasField('oneof_index'): 1181 if hasattr(f, 'proto3_optional') and f.proto3_optional: 1182 no_unions.append(f.oneof_index) 1183 1184 if f.oneof_index in no_unions: 1185 self.fields.append(field) 1186 elif f.oneof_index in self.oneofs: 1187 self.oneofs[f.oneof_index].add_field(field) 1188 1189 if self.oneofs[f.oneof_index] not in self.fields: 1190 self.fields.append(self.oneofs[f.oneof_index]) 1191 else: 1192 self.fields.append(field) 1193 1194 if field.math_include_required: 1195 self.math_include_required = True 1196 1197 if len(desc.extension_range) > 0: 1198 field_options = get_nanopb_suboptions(desc, message_options, self.name + 'extensions') 1199 range_start = min([r.start for r in desc.extension_range]) 1200 if field_options.type != nanopb_pb2.FT_IGNORE: 1201 self.fields.append(ExtensionRange(self.name, range_start, field_options)) 1202 1203 if message_options.sort_by_tag: 1204 self.fields.sort() 1205 1206 def get_dependencies(self): 1207 '''Get list of type names that this structure refers to.''' 1208 deps = [] 1209 for f in self.fields: 1210 deps += f.get_dependencies() 1211 return deps 1212 1213 def __str__(self): 1214 message_path = self.element_path() 1215 leading_comment, trailing_comment = self.get_comments(message_path, leading_indent=False) 1216 1217 result = '' 1218 if leading_comment: 1219 result = '%s\n' % leading_comment 1220 1221 result += 'typedef struct _%s { %s\n' % (self.name, trailing_comment) 1222 1223 if not self.fields: 1224 # Empty structs are not allowed in C standard. 1225 # Therefore add a dummy field if an empty message occurs. 1226 result += ' char dummy_field;' 1227 1228 msg_fields = [] 1229 for index, field in enumerate(self.fields): 1230 member_path = self.member_path(index) 1231 leading_comment, trailing_comment = self.get_comments(member_path) 1232 1233 if leading_comment: 1234 msg_fields.append(leading_comment) 1235 1236 msg_fields.append("%s %s" % (str(field), trailing_comment)) 1237 1238 result += '\n'.join(msg_fields) 1239 1240 if Globals.protoc_insertion_points: 1241 result += '\n/* @@protoc_insertion_point(struct:%s) */' % self.name 1242 1243 result += '\n}' 1244 1245 if self.packed: 1246 result += ' pb_packed' 1247 1248 result += ' %s;' % self.name 1249 1250 if self.packed: 1251 result = 'PB_PACKED_STRUCT_START\n' + result 1252 result += '\nPB_PACKED_STRUCT_END' 1253 1254 return result + '\n' 1255 1256 def types(self): 1257 return ''.join([f.types() for f in self.fields]) 1258 1259 def get_initializer(self, null_init): 1260 if not self.fields: 1261 return '{0}' 1262 1263 parts = [] 1264 for field in self.fields: 1265 parts.append(field.get_initializer(null_init)) 1266 return '{' + ', '.join(parts) + '}' 1267 1268 def count_required_fields(self): 1269 '''Returns number of required fields inside this message''' 1270 count = 0 1271 for f in self.fields: 1272 if not isinstance(f, OneOf): 1273 if f.rules == 'REQUIRED': 1274 count += 1 1275 return count 1276 1277 def all_fields(self): 1278 '''Iterate over all fields in this message, including nested OneOfs.''' 1279 for f in self.fields: 1280 if isinstance(f, OneOf): 1281 for f2 in f.fields: 1282 yield f2 1283 else: 1284 yield f 1285 1286 1287 def field_for_tag(self, tag): 1288 '''Given a tag number, return the Field instance.''' 1289 for field in self.all_fields(): 1290 if field.tag == tag: 1291 return field 1292 return None 1293 1294 def count_all_fields(self): 1295 '''Count the total number of fields in this message.''' 1296 count = 0 1297 for f in self.fields: 1298 if isinstance(f, OneOf): 1299 count += len(f.fields) 1300 else: 1301 count += 1 1302 return count 1303 1304 def fields_declaration(self, dependencies): 1305 '''Return X-macro declaration of all fields in this message.''' 1306 Field.macro_x_param = 'X' 1307 Field.macro_a_param = 'a' 1308 while any(field.name == Field.macro_x_param for field in self.all_fields()): 1309 Field.macro_x_param += '_' 1310 while any(field.name == Field.macro_a_param for field in self.all_fields()): 1311 Field.macro_a_param += '_' 1312 1313 # Field descriptor array must be sorted by tag number, pb_common.c relies on it. 1314 sorted_fields = list(self.all_fields()) 1315 sorted_fields.sort(key = lambda x: x.tag) 1316 1317 result = '#define %s_FIELDLIST(%s, %s) \\\n' % (self.name, 1318 Field.macro_x_param, 1319 Field.macro_a_param) 1320 result += ' \\\n'.join(x.fieldlist() for x in sorted_fields) 1321 result += '\n' 1322 1323 has_callbacks = bool([f for f in self.fields if f.has_callbacks()]) 1324 if has_callbacks: 1325 if self.callback_function != 'pb_default_field_callback': 1326 result += "extern bool %s(pb_istream_t *istream, pb_ostream_t *ostream, const pb_field_t *field);\n" % self.callback_function 1327 result += "#define %s_CALLBACK %s\n" % (self.name, self.callback_function) 1328 else: 1329 result += "#define %s_CALLBACK NULL\n" % self.name 1330 1331 defval = self.default_value(dependencies) 1332 if defval: 1333 hexcoded = ''.join("\\x%02x" % ord(defval[i:i+1]) for i in range(len(defval))) 1334 result += '#define %s_DEFAULT (const pb_byte_t*)"%s\\x00"\n' % (self.name, hexcoded) 1335 else: 1336 result += '#define %s_DEFAULT NULL\n' % self.name 1337 1338 for field in sorted_fields: 1339 if field.pbtype in ['MESSAGE', 'MSG_W_CB']: 1340 if field.rules == 'ONEOF': 1341 result += "#define %s_%s_%s_MSGTYPE %s\n" % (self.name, field.union_name, field.name, field.ctype) 1342 else: 1343 result += "#define %s_%s_MSGTYPE %s\n" % (self.name, field.name, field.ctype) 1344 1345 return result 1346 1347 def fields_declaration_cpp_lookup(self): 1348 result = 'template <>\n' 1349 result += 'struct MessageDescriptor<%s> {\n' % (self.name) 1350 result += ' static PB_INLINE_CONSTEXPR const pb_size_t fields_array_length = %d;\n' % (self.count_all_fields()) 1351 result += ' static inline const pb_msgdesc_t* fields() {\n' 1352 result += ' return &%s_msg;\n' % (self.name) 1353 result += ' }\n' 1354 result += '};' 1355 return result 1356 1357 def fields_definition(self, dependencies): 1358 '''Return the field descriptor definition that goes in .pb.c file.''' 1359 width = self.required_descriptor_width(dependencies) 1360 if width == 1: 1361 width = 'AUTO' 1362 1363 result = 'PB_BIND(%s, %s, %s)\n' % (self.name, self.name, width) 1364 return result 1365 1366 def required_descriptor_width(self, dependencies): 1367 '''Estimate how many words are necessary for each field descriptor.''' 1368 if self.descriptorsize != nanopb_pb2.DS_AUTO: 1369 return int(self.descriptorsize) 1370 1371 if not self.fields: 1372 return 1 1373 1374 max_tag = max(field.tag for field in self.all_fields()) 1375 max_offset = self.data_size(dependencies) 1376 max_arraysize = max((field.max_count or 0) for field in self.all_fields()) 1377 max_datasize = max(field.data_size(dependencies) for field in self.all_fields()) 1378 1379 if max_arraysize > 0xFFFF: 1380 return 8 1381 elif (max_tag > 0x3FF or max_offset > 0xFFFF or 1382 max_arraysize > 0x0FFF or max_datasize > 0x0FFF): 1383 return 4 1384 elif max_tag > 0x3F or max_offset > 0xFF: 1385 return 2 1386 else: 1387 # NOTE: Macro logic in pb.h ensures that width 1 will 1388 # be raised to 2 automatically for string/submsg fields 1389 # and repeated fields. Thus only tag and offset need to 1390 # be checked. 1391 return 1 1392 1393 def data_size(self, dependencies): 1394 '''Return approximate sizeof(struct) in the compiled code.''' 1395 return sum(f.data_size(dependencies) for f in self.fields) 1396 1397 def encoded_size(self, dependencies): 1398 '''Return the maximum size that this message can take when encoded. 1399 If the size cannot be determined, returns None. 1400 ''' 1401 size = EncodedSize(0) 1402 for field in self.fields: 1403 fsize = field.encoded_size(dependencies) 1404 if fsize is None: 1405 return None 1406 size += fsize 1407 1408 return size 1409 1410 def default_value(self, dependencies): 1411 '''Generate serialized protobuf message that contains the 1412 default values for optional fields.''' 1413 1414 if not self.desc: 1415 return b'' 1416 1417 if self.desc.options.map_entry: 1418 return b'' 1419 1420 optional_only = copy.deepcopy(self.desc) 1421 1422 # Remove fields without default values 1423 # The iteration is done in reverse order to avoid remove() messing up iteration. 1424 for field in reversed(list(optional_only.field)): 1425 field.ClearField(str('extendee')) 1426 parsed_field = self.field_for_tag(field.number) 1427 if parsed_field is None or parsed_field.allocation != 'STATIC': 1428 optional_only.field.remove(field) 1429 elif (field.label == FieldD.LABEL_REPEATED or 1430 field.type == FieldD.TYPE_MESSAGE): 1431 optional_only.field.remove(field) 1432 elif hasattr(field, 'oneof_index') and field.HasField('oneof_index'): 1433 optional_only.field.remove(field) 1434 elif field.type == FieldD.TYPE_ENUM: 1435 # The partial descriptor doesn't include the enum type 1436 # so we fake it with int64. 1437 enumname = names_from_type_name(field.type_name) 1438 try: 1439 enumtype = dependencies[str(enumname)] 1440 except KeyError: 1441 raise Exception("Could not find enum type %s while generating default values for %s.\n" % (enumname, self.name) 1442 + "Try passing all source files to generator at once, or use -I option.") 1443 1444 if field.HasField('default_value'): 1445 defvals = [v for n,v in enumtype.values if n.parts[-1] == field.default_value] 1446 else: 1447 # If no default is specified, the default is the first value. 1448 defvals = [v for n,v in enumtype.values] 1449 if defvals and defvals[0] != 0: 1450 field.type = FieldD.TYPE_INT64 1451 field.default_value = str(defvals[0]) 1452 field.ClearField(str('type_name')) 1453 else: 1454 optional_only.field.remove(field) 1455 elif not field.HasField('default_value'): 1456 optional_only.field.remove(field) 1457 1458 if len(optional_only.field) == 0: 1459 return b'' 1460 1461 optional_only.ClearField(str('oneof_decl')) 1462 optional_only.ClearField(str('nested_type')) 1463 optional_only.ClearField(str('extension')) 1464 optional_only.ClearField(str('enum_type')) 1465 desc = google.protobuf.descriptor.MakeDescriptor(optional_only) 1466 msg = reflection.MakeClass(desc)() 1467 1468 for field in optional_only.field: 1469 if field.type == FieldD.TYPE_STRING: 1470 setattr(msg, field.name, field.default_value) 1471 elif field.type == FieldD.TYPE_BYTES: 1472 setattr(msg, field.name, codecs.escape_decode(field.default_value)[0]) 1473 elif field.type in [FieldD.TYPE_FLOAT, FieldD.TYPE_DOUBLE]: 1474 setattr(msg, field.name, float(field.default_value)) 1475 elif field.type == FieldD.TYPE_BOOL: 1476 setattr(msg, field.name, field.default_value == 'true') 1477 else: 1478 setattr(msg, field.name, int(field.default_value)) 1479 1480 return msg.SerializeToString() 1481 1482 1483# --------------------------------------------------------------------------- 1484# Processing of entire .proto files 1485# --------------------------------------------------------------------------- 1486 1487def iterate_messages(desc, flatten = False, names = Names()): 1488 '''Recursively find all messages. For each, yield name, DescriptorProto.''' 1489 if hasattr(desc, 'message_type'): 1490 submsgs = desc.message_type 1491 else: 1492 submsgs = desc.nested_type 1493 1494 for submsg in submsgs: 1495 sub_names = names + submsg.name 1496 if flatten: 1497 yield Names(submsg.name), submsg 1498 else: 1499 yield sub_names, submsg 1500 1501 for x in iterate_messages(submsg, flatten, sub_names): 1502 yield x 1503 1504def iterate_extensions(desc, flatten = False, names = Names()): 1505 '''Recursively find all extensions. 1506 For each, yield name, FieldDescriptorProto. 1507 ''' 1508 for extension in desc.extension: 1509 yield names, extension 1510 1511 for subname, subdesc in iterate_messages(desc, flatten, names): 1512 for extension in subdesc.extension: 1513 yield subname, extension 1514 1515def toposort2(data): 1516 '''Topological sort. 1517 From http://code.activestate.com/recipes/577413-topological-sort/ 1518 This function is under the MIT license. 1519 ''' 1520 for k, v in list(data.items()): 1521 v.discard(k) # Ignore self dependencies 1522 extra_items_in_deps = reduce(set.union, list(data.values()), set()) - set(data.keys()) 1523 data.update(dict([(item, set()) for item in extra_items_in_deps])) 1524 while True: 1525 ordered = set(item for item,dep in list(data.items()) if not dep) 1526 if not ordered: 1527 break 1528 for item in sorted(ordered): 1529 yield item 1530 data = dict([(item, (dep - ordered)) for item,dep in list(data.items()) 1531 if item not in ordered]) 1532 assert not data, "A cyclic dependency exists amongst %r" % data 1533 1534def sort_dependencies(messages): 1535 '''Sort a list of Messages based on dependencies.''' 1536 dependencies = {} 1537 message_by_name = {} 1538 for message in messages: 1539 dependencies[str(message.name)] = set(message.get_dependencies()) 1540 message_by_name[str(message.name)] = message 1541 1542 for msgname in toposort2(dependencies): 1543 if msgname in message_by_name: 1544 yield message_by_name[msgname] 1545 1546def make_identifier(headername): 1547 '''Make #ifndef identifier that contains uppercase A-Z and digits 0-9''' 1548 result = "" 1549 for c in headername.upper(): 1550 if c.isalnum(): 1551 result += c 1552 else: 1553 result += '_' 1554 return result 1555 1556class ProtoFile: 1557 def __init__(self, fdesc, file_options): 1558 '''Takes a FileDescriptorProto and parses it.''' 1559 self.fdesc = fdesc 1560 self.file_options = file_options 1561 self.dependencies = {} 1562 self.math_include_required = False 1563 self.parse() 1564 for message in self.messages: 1565 if message.math_include_required: 1566 self.math_include_required = True 1567 break 1568 1569 # Some of types used in this file probably come from the file itself. 1570 # Thus it has implicit dependency on itself. 1571 self.add_dependency(self) 1572 1573 def parse(self): 1574 self.enums = [] 1575 self.messages = [] 1576 self.extensions = [] 1577 1578 mangle_names = self.file_options.mangle_names 1579 flatten = mangle_names == nanopb_pb2.M_FLATTEN 1580 strip_prefix = None 1581 replacement_prefix = None 1582 if mangle_names == nanopb_pb2.M_STRIP_PACKAGE: 1583 strip_prefix = "." + self.fdesc.package 1584 elif mangle_names == nanopb_pb2.M_PACKAGE_INITIALS: 1585 strip_prefix = "." + self.fdesc.package 1586 replacement_prefix = "" 1587 for part in self.fdesc.package.split("."): 1588 replacement_prefix += part[0] 1589 elif self.file_options.package: 1590 strip_prefix = "." + self.fdesc.package 1591 replacement_prefix = self.file_options.package 1592 1593 1594 def create_name(names): 1595 if mangle_names in (nanopb_pb2.M_NONE, nanopb_pb2.M_PACKAGE_INITIALS): 1596 return base_name + names 1597 if mangle_names == nanopb_pb2.M_STRIP_PACKAGE: 1598 return Names(names) 1599 single_name = names 1600 if isinstance(names, Names): 1601 single_name = names.parts[-1] 1602 return Names(single_name) 1603 1604 def mangle_field_typename(typename): 1605 if mangle_names == nanopb_pb2.M_FLATTEN: 1606 return "." + typename.split(".")[-1] 1607 if strip_prefix is not None and typename.startswith(strip_prefix): 1608 if replacement_prefix is not None: 1609 return "." + replacement_prefix + typename[len(strip_prefix):] 1610 else: 1611 return typename[len(strip_prefix):] 1612 if self.file_options.package: 1613 return "." + replacement_prefix + typename 1614 return typename 1615 1616 if replacement_prefix is not None: 1617 base_name = Names(replacement_prefix.split('.')) 1618 elif self.fdesc.package: 1619 base_name = Names(self.fdesc.package.split('.')) 1620 else: 1621 base_name = Names() 1622 1623 # process source code comment locations 1624 # ignores any locations that do not contain any comment information 1625 self.comment_locations = { 1626 str(list(location.path)): location 1627 for location in self.fdesc.source_code_info.location 1628 if location.leading_comments or location.leading_detached_comments or location.trailing_comments 1629 } 1630 1631 for index, enum in enumerate(self.fdesc.enum_type): 1632 name = create_name(enum.name) 1633 enum_options = get_nanopb_suboptions(enum, self.file_options, name) 1634 self.enums.append(Enum(name, enum, enum_options, index, self.comment_locations)) 1635 1636 for index, (names, message) in enumerate(iterate_messages(self.fdesc, flatten)): 1637 name = create_name(names) 1638 message_options = get_nanopb_suboptions(message, self.file_options, name) 1639 1640 if message_options.skip_message: 1641 continue 1642 1643 message = copy.deepcopy(message) 1644 for field in message.field: 1645 if field.type in (FieldD.TYPE_MESSAGE, FieldD.TYPE_ENUM): 1646 field.type_name = mangle_field_typename(field.type_name) 1647 1648 self.messages.append(Message(name, message, message_options, index, self.comment_locations)) 1649 for index, enum in enumerate(message.enum_type): 1650 name = create_name(names + enum.name) 1651 enum_options = get_nanopb_suboptions(enum, message_options, name) 1652 self.enums.append(Enum(name, enum, enum_options, index, self.comment_locations)) 1653 1654 for names, extension in iterate_extensions(self.fdesc, flatten): 1655 name = create_name(names + extension.name) 1656 field_options = get_nanopb_suboptions(extension, self.file_options, name) 1657 1658 extension = copy.deepcopy(extension) 1659 if extension.type in (FieldD.TYPE_MESSAGE, FieldD.TYPE_ENUM): 1660 extension.type_name = mangle_field_typename(extension.type_name) 1661 1662 if field_options.type != nanopb_pb2.FT_IGNORE: 1663 self.extensions.append(ExtensionField(name, extension, field_options)) 1664 1665 def add_dependency(self, other): 1666 for enum in other.enums: 1667 self.dependencies[str(enum.names)] = enum 1668 enum.protofile = other 1669 1670 for msg in other.messages: 1671 self.dependencies[str(msg.name)] = msg 1672 msg.protofile = other 1673 1674 # Fix field default values where enum short names are used. 1675 for enum in other.enums: 1676 if not enum.options.long_names: 1677 for message in self.messages: 1678 for field in message.all_fields(): 1679 if field.default in enum.value_longnames: 1680 idx = enum.value_longnames.index(field.default) 1681 field.default = enum.values[idx][0] 1682 1683 # Fix field data types where enums have negative values. 1684 for enum in other.enums: 1685 if not enum.has_negative(): 1686 for message in self.messages: 1687 for field in message.all_fields(): 1688 if field.pbtype == 'ENUM' and field.ctype == enum.names: 1689 field.pbtype = 'UENUM' 1690 1691 def generate_header(self, includes, headername, options): 1692 '''Generate content for a header file. 1693 Generates strings, which should be concatenated and stored to file. 1694 ''' 1695 1696 yield '/* Automatically generated nanopb header */\n' 1697 if options.notimestamp: 1698 yield '/* Generated by %s */\n\n' % (nanopb_version) 1699 else: 1700 yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime()) 1701 1702 if self.fdesc.package: 1703 symbol = make_identifier(self.fdesc.package + '_' + headername) 1704 else: 1705 symbol = make_identifier(headername) 1706 yield '#ifndef PB_%s_INCLUDED\n' % symbol 1707 yield '#define PB_%s_INCLUDED\n' % symbol 1708 if self.math_include_required: 1709 yield '#include <math.h>\n' 1710 try: 1711 yield options.libformat % ('pb.h') 1712 except TypeError: 1713 # no %s specified - use whatever was passed in as options.libformat 1714 yield options.libformat 1715 yield '\n' 1716 1717 for incfile in self.file_options.include: 1718 # allow including system headers 1719 if (incfile.startswith('<')): 1720 yield '#include %s\n' % incfile 1721 else: 1722 yield options.genformat % incfile 1723 yield '\n' 1724 1725 for incfile in includes: 1726 noext = os.path.splitext(incfile)[0] 1727 yield options.genformat % (noext + options.extension + options.header_extension) 1728 yield '\n' 1729 1730 if Globals.protoc_insertion_points: 1731 yield '/* @@protoc_insertion_point(includes) */\n' 1732 1733 yield '\n' 1734 1735 yield '#if PB_PROTO_HEADER_VERSION != 40\n' 1736 yield '#error Regenerate this file with the current version of nanopb generator.\n' 1737 yield '#endif\n' 1738 yield '\n' 1739 1740 if self.enums: 1741 yield '/* Enum definitions */\n' 1742 for enum in self.enums: 1743 yield str(enum) + '\n\n' 1744 1745 if self.messages: 1746 yield '/* Struct definitions */\n' 1747 for msg in sort_dependencies(self.messages): 1748 yield msg.types() 1749 yield str(msg) + '\n' 1750 yield '\n' 1751 1752 if self.extensions: 1753 yield '/* Extensions */\n' 1754 for extension in self.extensions: 1755 yield extension.extension_decl() 1756 yield '\n' 1757 1758 if self.enums: 1759 yield '/* Helper constants for enums */\n' 1760 for enum in self.enums: 1761 yield enum.auxiliary_defines() + '\n' 1762 yield '\n' 1763 1764 yield '#ifdef __cplusplus\n' 1765 yield 'extern "C" {\n' 1766 yield '#endif\n\n' 1767 1768 if self.messages: 1769 yield '/* Initializer values for message structs */\n' 1770 for msg in self.messages: 1771 identifier = '%s_init_default' % msg.name 1772 yield '#define %-40s %s\n' % (identifier, msg.get_initializer(False)) 1773 for msg in self.messages: 1774 identifier = '%s_init_zero' % msg.name 1775 yield '#define %-40s %s\n' % (identifier, msg.get_initializer(True)) 1776 yield '\n' 1777 1778 yield '/* Field tags (for use in manual encoding/decoding) */\n' 1779 for msg in sort_dependencies(self.messages): 1780 for field in msg.fields: 1781 yield field.tags() 1782 for extension in self.extensions: 1783 yield extension.tags() 1784 yield '\n' 1785 1786 yield '/* Struct field encoding specification for nanopb */\n' 1787 for msg in self.messages: 1788 yield msg.fields_declaration(self.dependencies) + '\n' 1789 for msg in self.messages: 1790 yield 'extern const pb_msgdesc_t %s_msg;\n' % msg.name 1791 yield '\n' 1792 1793 yield '/* Defines for backwards compatibility with code written before nanopb-0.4.0 */\n' 1794 for msg in self.messages: 1795 yield '#define %s_fields &%s_msg\n' % (msg.name, msg.name) 1796 yield '\n' 1797 1798 yield '/* Maximum encoded size of messages (where known) */\n' 1799 messagesizes = [] 1800 for msg in self.messages: 1801 identifier = '%s_size' % msg.name 1802 messagesizes.append((identifier, msg.encoded_size(self.dependencies))) 1803 1804 # If we require a symbol from another file, put a preprocessor if statement 1805 # around it to prevent compilation errors if the symbol is not actually available. 1806 local_defines = [identifier for identifier, msize in messagesizes if msize is not None] 1807 guards = {} 1808 for identifier, msize in messagesizes: 1809 if msize is not None: 1810 cpp_guard = msize.get_cpp_guard(local_defines) 1811 if cpp_guard not in guards: 1812 guards[cpp_guard] = set() 1813 for decl in msize.get_declarations().splitlines(): 1814 guards[cpp_guard].add(decl) 1815 guards[cpp_guard].add('#define %-40s %s' % (identifier, msize)) 1816 else: 1817 yield '/* %s depends on runtime parameters */\n' % identifier 1818 for guard, values in guards.items(): 1819 if guard: 1820 yield guard 1821 for v in sorted(values): 1822 yield v 1823 yield '\n' 1824 if guard: 1825 yield '#endif\n' 1826 yield '\n' 1827 1828 if [msg for msg in self.messages if hasattr(msg,'msgid')]: 1829 yield '/* Message IDs (where set with "msgid" option) */\n' 1830 for msg in self.messages: 1831 if hasattr(msg,'msgid'): 1832 yield '#define PB_MSG_%d %s\n' % (msg.msgid, msg.name) 1833 yield '\n' 1834 1835 symbol = make_identifier(headername.split('.')[0]) 1836 yield '#define %s_MESSAGES \\\n' % symbol 1837 1838 for msg in self.messages: 1839 m = "-1" 1840 msize = msg.encoded_size(self.dependencies) 1841 if msize is not None: 1842 m = msize 1843 if hasattr(msg,'msgid'): 1844 yield '\tPB_MSG(%d,%s,%s) \\\n' % (msg.msgid, m, msg.name) 1845 yield '\n' 1846 1847 for msg in self.messages: 1848 if hasattr(msg,'msgid'): 1849 yield '#define %s_msgid %d\n' % (msg.name, msg.msgid) 1850 yield '\n' 1851 1852 yield '#ifdef __cplusplus\n' 1853 yield '} /* extern "C" */\n' 1854 yield '#endif\n' 1855 1856 if options.cpp_descriptors: 1857 yield '\n' 1858 yield '#ifdef __cplusplus\n' 1859 yield '/* Message descriptors for nanopb */\n' 1860 yield 'namespace nanopb {\n' 1861 for msg in self.messages: 1862 yield msg.fields_declaration_cpp_lookup() + '\n' 1863 yield '} // namespace nanopb\n' 1864 yield '\n' 1865 yield '#endif /* __cplusplus */\n' 1866 yield '\n' 1867 1868 if Globals.protoc_insertion_points: 1869 yield '/* @@protoc_insertion_point(eof) */\n' 1870 1871 # End of header 1872 yield '\n#endif\n' 1873 1874 def generate_source(self, headername, options): 1875 '''Generate content for a source file.''' 1876 1877 yield '/* Automatically generated nanopb constant definitions */\n' 1878 if options.notimestamp: 1879 yield '/* Generated by %s */\n\n' % (nanopb_version) 1880 else: 1881 yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime()) 1882 yield options.genformat % (headername) 1883 yield '\n' 1884 1885 if Globals.protoc_insertion_points: 1886 yield '/* @@protoc_insertion_point(includes) */\n' 1887 1888 yield '#if PB_PROTO_HEADER_VERSION != 40\n' 1889 yield '#error Regenerate this file with the current version of nanopb generator.\n' 1890 yield '#endif\n' 1891 yield '\n' 1892 1893 for msg in self.messages: 1894 yield msg.fields_definition(self.dependencies) + '\n\n' 1895 1896 for ext in self.extensions: 1897 yield ext.extension_def(self.dependencies) + '\n' 1898 1899 for enum in self.enums: 1900 yield enum.enum_to_string_definition() + '\n' 1901 1902 # Add checks for numeric limits 1903 if self.messages: 1904 largest_msg = max(self.messages, key = lambda m: m.count_required_fields()) 1905 largest_count = largest_msg.count_required_fields() 1906 if largest_count > 64: 1907 yield '\n/* Check that missing required fields will be properly detected */\n' 1908 yield '#if PB_MAX_REQUIRED_FIELDS < %d\n' % largest_count 1909 yield '#error Properly detecting missing required fields in %s requires \\\n' % largest_msg.name 1910 yield ' setting PB_MAX_REQUIRED_FIELDS to %d or more.\n' % largest_count 1911 yield '#endif\n' 1912 1913 # Add check for sizeof(double) 1914 has_double = False 1915 for msg in self.messages: 1916 for field in msg.all_fields(): 1917 if field.ctype == 'double': 1918 has_double = True 1919 1920 if has_double: 1921 yield '\n' 1922 yield '#ifndef PB_CONVERT_DOUBLE_FLOAT\n' 1923 yield '/* On some platforms (such as AVR), double is really float.\n' 1924 yield ' * To be able to encode/decode double on these platforms, you need.\n' 1925 yield ' * to define PB_CONVERT_DOUBLE_FLOAT in pb.h or compiler command line.\n' 1926 yield ' */\n' 1927 yield 'PB_STATIC_ASSERT(sizeof(double) == 8, DOUBLE_MUST_BE_8_BYTES)\n' 1928 yield '#endif\n' 1929 1930 yield '\n' 1931 1932 if Globals.protoc_insertion_points: 1933 yield '/* @@protoc_insertion_point(eof) */\n' 1934 1935# --------------------------------------------------------------------------- 1936# Options parsing for the .proto files 1937# --------------------------------------------------------------------------- 1938 1939from fnmatch import fnmatchcase 1940 1941def read_options_file(infile): 1942 '''Parse a separate options file to list: 1943 [(namemask, options), ...] 1944 ''' 1945 results = [] 1946 data = infile.read() 1947 data = re.sub(r'/\*.*?\*/', '', data, flags = re.MULTILINE) 1948 data = re.sub(r'//.*?$', '', data, flags = re.MULTILINE) 1949 data = re.sub(r'#.*?$', '', data, flags = re.MULTILINE) 1950 for i, line in enumerate(data.split('\n')): 1951 line = line.strip() 1952 if not line: 1953 continue 1954 1955 parts = line.split(None, 1) 1956 1957 if len(parts) < 2: 1958 sys.stderr.write("%s:%d: " % (infile.name, i + 1) + 1959 "Option lines should have space between field name and options. " + 1960 "Skipping line: '%s'\n" % line) 1961 sys.exit(1) 1962 1963 opts = nanopb_pb2.NanoPBOptions() 1964 1965 try: 1966 text_format.Merge(parts[1], opts) 1967 except Exception as e: 1968 sys.stderr.write("%s:%d: " % (infile.name, i + 1) + 1969 "Unparseable option line: '%s'. " % line + 1970 "Error: %s\n" % str(e)) 1971 sys.exit(1) 1972 results.append((parts[0], opts)) 1973 1974 return results 1975 1976def get_nanopb_suboptions(subdesc, options, name): 1977 '''Get copy of options, and merge information from subdesc.''' 1978 new_options = nanopb_pb2.NanoPBOptions() 1979 new_options.CopyFrom(options) 1980 1981 if hasattr(subdesc, 'syntax') and subdesc.syntax == "proto3": 1982 new_options.proto3 = True 1983 1984 # Handle options defined in a separate file 1985 dotname = '.'.join(name.parts) 1986 for namemask, options in Globals.separate_options: 1987 if fnmatchcase(dotname, namemask): 1988 Globals.matched_namemasks.add(namemask) 1989 new_options.MergeFrom(options) 1990 1991 # Handle options defined in .proto 1992 if isinstance(subdesc.options, descriptor.FieldOptions): 1993 ext_type = nanopb_pb2.nanopb 1994 elif isinstance(subdesc.options, descriptor.FileOptions): 1995 ext_type = nanopb_pb2.nanopb_fileopt 1996 elif isinstance(subdesc.options, descriptor.MessageOptions): 1997 ext_type = nanopb_pb2.nanopb_msgopt 1998 elif isinstance(subdesc.options, descriptor.EnumOptions): 1999 ext_type = nanopb_pb2.nanopb_enumopt 2000 else: 2001 raise Exception("Unknown options type") 2002 2003 if subdesc.options.HasExtension(ext_type): 2004 ext = subdesc.options.Extensions[ext_type] 2005 new_options.MergeFrom(ext) 2006 2007 if Globals.verbose_options: 2008 sys.stderr.write("Options for " + dotname + ": ") 2009 sys.stderr.write(text_format.MessageToString(new_options) + "\n") 2010 2011 return new_options 2012 2013 2014# --------------------------------------------------------------------------- 2015# Command line interface 2016# --------------------------------------------------------------------------- 2017 2018import sys 2019import os.path 2020from optparse import OptionParser 2021 2022optparser = OptionParser( 2023 usage = "Usage: nanopb_generator.py [options] file.pb ...", 2024 epilog = "Compile file.pb from file.proto by: 'protoc -ofile.pb file.proto'. " + 2025 "Output will be written to file.pb.h and file.pb.c.") 2026optparser.add_option("--version", dest="version", action="store_true", 2027 help="Show version info and exit") 2028optparser.add_option("-x", dest="exclude", metavar="FILE", action="append", default=[], 2029 help="Exclude file from generated #include list.") 2030optparser.add_option("-e", "--extension", dest="extension", metavar="EXTENSION", default=".pb", 2031 help="Set extension to use instead of '.pb' for generated files. [default: %default]") 2032optparser.add_option("-H", "--header-extension", dest="header_extension", metavar="EXTENSION", default=".h", 2033 help="Set extension to use for generated header files. [default: %default]") 2034optparser.add_option("-S", "--source-extension", dest="source_extension", metavar="EXTENSION", default=".c", 2035 help="Set extension to use for generated source files. [default: %default]") 2036optparser.add_option("-f", "--options-file", dest="options_file", metavar="FILE", default="%s.options", 2037 help="Set name of a separate generator options file.") 2038optparser.add_option("-I", "--options-path", dest="options_path", metavar="DIR", 2039 action="append", default = [], 2040 help="Search for .options files additionally in this path") 2041optparser.add_option("--error-on-unmatched", dest="error_on_unmatched", action="store_true", default=False, 2042 help ="Stop generation if there are unmatched fields in options file") 2043optparser.add_option("--no-error-on-unmatched", dest="error_on_unmatched", action="store_false", default=False, 2044 help ="Continue generation if there are unmatched fields in options file (default)") 2045optparser.add_option("-D", "--output-dir", dest="output_dir", 2046 metavar="OUTPUTDIR", default=None, 2047 help="Output directory of .pb.h and .pb.c files") 2048optparser.add_option("-Q", "--generated-include-format", dest="genformat", 2049 metavar="FORMAT", default='#include "%s"', 2050 help="Set format string to use for including other .pb.h files. [default: %default]") 2051optparser.add_option("-L", "--library-include-format", dest="libformat", 2052 metavar="FORMAT", default='#include <%s>', 2053 help="Set format string to use for including the nanopb pb.h header. [default: %default]") 2054optparser.add_option("--strip-path", dest="strip_path", action="store_true", default=False, 2055 help="Strip directory path from #included .pb.h file name") 2056optparser.add_option("--no-strip-path", dest="strip_path", action="store_false", 2057 help="Opposite of --strip-path (default since 0.4.0)") 2058optparser.add_option("--cpp-descriptors", action="store_true", 2059 help="Generate C++ descriptors to lookup by type (e.g. pb_field_t for a message)") 2060optparser.add_option("-T", "--no-timestamp", dest="notimestamp", action="store_true", default=True, 2061 help="Don't add timestamp to .pb.h and .pb.c preambles (default since 0.4.0)") 2062optparser.add_option("-t", "--timestamp", dest="notimestamp", action="store_false", default=True, 2063 help="Add timestamp to .pb.h and .pb.c preambles") 2064optparser.add_option("-q", "--quiet", dest="quiet", action="store_true", default=False, 2065 help="Don't print anything except errors.") 2066optparser.add_option("-v", "--verbose", dest="verbose", action="store_true", default=False, 2067 help="Print more information.") 2068optparser.add_option("-s", dest="settings", metavar="OPTION:VALUE", action="append", default=[], 2069 help="Set generator option (max_size, max_count etc.).") 2070optparser.add_option("--protoc-insertion-points", dest="protoc_insertion_points", action="store_true", default=False, 2071 help="Include insertion point comments in output for use by custom protoc plugins") 2072 2073def parse_file(filename, fdesc, options): 2074 '''Parse a single file. Returns a ProtoFile instance.''' 2075 toplevel_options = nanopb_pb2.NanoPBOptions() 2076 for s in options.settings: 2077 text_format.Merge(s, toplevel_options) 2078 2079 if not fdesc: 2080 data = open(filename, 'rb').read() 2081 fdesc = descriptor.FileDescriptorSet.FromString(data).file[0] 2082 2083 # Check if there is a separate .options file 2084 had_abspath = False 2085 try: 2086 optfilename = options.options_file % os.path.splitext(filename)[0] 2087 except TypeError: 2088 # No %s specified, use the filename as-is 2089 optfilename = options.options_file 2090 had_abspath = True 2091 2092 paths = ['.'] + options.options_path 2093 for p in paths: 2094 if os.path.isfile(os.path.join(p, optfilename)): 2095 optfilename = os.path.join(p, optfilename) 2096 if options.verbose: 2097 sys.stderr.write('Reading options from ' + optfilename + '\n') 2098 Globals.separate_options = read_options_file(open(optfilename, openmode_unicode)) 2099 break 2100 else: 2101 # If we are given a full filename and it does not exist, give an error. 2102 # However, don't give error when we automatically look for .options file 2103 # with the same name as .proto. 2104 if options.verbose or had_abspath: 2105 sys.stderr.write('Options file not found: ' + optfilename + '\n') 2106 Globals.separate_options = [] 2107 2108 Globals.matched_namemasks = set() 2109 Globals.protoc_insertion_points = options.protoc_insertion_points 2110 2111 # Parse the file 2112 file_options = get_nanopb_suboptions(fdesc, toplevel_options, Names([filename])) 2113 f = ProtoFile(fdesc, file_options) 2114 f.optfilename = optfilename 2115 2116 return f 2117 2118def process_file(filename, fdesc, options, other_files = {}): 2119 '''Process a single file. 2120 filename: The full path to the .proto or .pb source file, as string. 2121 fdesc: The loaded FileDescriptorSet, or None to read from the input file. 2122 options: Command line options as they come from OptionsParser. 2123 2124 Returns a dict: 2125 {'headername': Name of header file, 2126 'headerdata': Data for the .h header file, 2127 'sourcename': Name of the source code file, 2128 'sourcedata': Data for the .c source code file 2129 } 2130 ''' 2131 f = parse_file(filename, fdesc, options) 2132 2133 # Check the list of dependencies, and if they are available in other_files, 2134 # add them to be considered for import resolving. Recursively add any files 2135 # imported by the dependencies. 2136 deps = list(f.fdesc.dependency) 2137 while deps: 2138 dep = deps.pop(0) 2139 if dep in other_files: 2140 f.add_dependency(other_files[dep]) 2141 deps += list(other_files[dep].fdesc.dependency) 2142 2143 # Decide the file names 2144 noext = os.path.splitext(filename)[0] 2145 headername = noext + options.extension + options.header_extension 2146 sourcename = noext + options.extension + options.source_extension 2147 2148 if options.strip_path: 2149 headerbasename = os.path.basename(headername) 2150 else: 2151 headerbasename = headername 2152 2153 # List of .proto files that should not be included in the C header file 2154 # even if they are mentioned in the source .proto. 2155 excludes = ['nanopb.proto', 'google/protobuf/descriptor.proto'] + options.exclude + list(f.file_options.exclude) 2156 includes = [d for d in f.fdesc.dependency if d not in excludes] 2157 2158 headerdata = ''.join(f.generate_header(includes, headerbasename, options)) 2159 sourcedata = ''.join(f.generate_source(headerbasename, options)) 2160 2161 # Check if there were any lines in .options that did not match a member 2162 unmatched = [n for n,o in Globals.separate_options if n not in Globals.matched_namemasks] 2163 if unmatched: 2164 if options.error_on_unmatched: 2165 raise Exception("Following patterns in " + f.optfilename + " did not match any fields: " 2166 + ', '.join(unmatched)); 2167 elif not options.quiet: 2168 sys.stderr.write("Following patterns in " + f.optfilename + " did not match any fields: " 2169 + ', '.join(unmatched) + "\n") 2170 2171 if not Globals.verbose_options: 2172 sys.stderr.write("Use protoc --nanopb-out=-v:. to see a list of the field names.\n") 2173 2174 return {'headername': headername, 'headerdata': headerdata, 2175 'sourcename': sourcename, 'sourcedata': sourcedata} 2176 2177def main_cli(): 2178 '''Main function when invoked directly from the command line.''' 2179 2180 options, filenames = optparser.parse_args() 2181 2182 if options.version: 2183 print(nanopb_version) 2184 sys.exit(0) 2185 2186 if not filenames: 2187 optparser.print_help() 2188 sys.exit(1) 2189 2190 if options.quiet: 2191 options.verbose = False 2192 2193 if options.output_dir and not os.path.exists(options.output_dir): 2194 optparser.print_help() 2195 sys.stderr.write("\noutput_dir does not exist: %s\n" % options.output_dir) 2196 sys.exit(1) 2197 2198 if options.verbose: 2199 sys.stderr.write("Nanopb version %s\n" % nanopb_version) 2200 sys.stderr.write('Google Python protobuf library imported from %s, version %s\n' 2201 % (google.protobuf.__file__, google.protobuf.__version__)) 2202 2203 # Load .pb files into memory and compile any .proto files. 2204 fdescs = {} 2205 include_path = ['-I%s' % p for p in options.options_path] 2206 for filename in filenames: 2207 if filename.endswith(".proto"): 2208 with TemporaryDirectory() as tmpdir: 2209 tmpname = os.path.join(tmpdir, os.path.basename(filename) + ".pb") 2210 status = invoke_protoc(["protoc"] + include_path + ['--include_imports', '--include_source_info', '-o' + tmpname, filename]) 2211 if status != 0: sys.exit(status) 2212 data = open(tmpname, 'rb').read() 2213 else: 2214 data = open(filename, 'rb').read() 2215 2216 fdesc = descriptor.FileDescriptorSet.FromString(data).file[-1] 2217 fdescs[fdesc.name] = fdesc 2218 2219 # Process any include files first, in order to have them 2220 # available as dependencies 2221 other_files = {} 2222 for fdesc in fdescs.values(): 2223 other_files[fdesc.name] = parse_file(fdesc.name, fdesc, options) 2224 2225 # Then generate the headers / sources 2226 Globals.verbose_options = options.verbose 2227 for fdesc in fdescs.values(): 2228 results = process_file(fdesc.name, fdesc, options, other_files) 2229 2230 base_dir = options.output_dir or '' 2231 to_write = [ 2232 (os.path.join(base_dir, results['headername']), results['headerdata']), 2233 (os.path.join(base_dir, results['sourcename']), results['sourcedata']), 2234 ] 2235 2236 if not options.quiet: 2237 paths = " and ".join([x[0] for x in to_write]) 2238 sys.stderr.write("Writing to %s\n" % paths) 2239 2240 for path, data in to_write: 2241 dirname = os.path.dirname(path) 2242 if dirname and not os.path.exists(dirname): 2243 os.makedirs(dirname) 2244 2245 with open(path, 'w') as f: 2246 f.write(data) 2247 2248def main_plugin(): 2249 '''Main function when invoked as a protoc plugin.''' 2250 2251 import io, sys 2252 if sys.platform == "win32": 2253 import os, msvcrt 2254 # Set stdin and stdout to binary mode 2255 msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY) 2256 msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY) 2257 2258 data = io.open(sys.stdin.fileno(), "rb").read() 2259 2260 request = plugin_pb2.CodeGeneratorRequest.FromString(data) 2261 2262 try: 2263 # Versions of Python prior to 2.7.3 do not support unicode 2264 # input to shlex.split(). Try to convert to str if possible. 2265 params = str(request.parameter) 2266 except UnicodeEncodeError: 2267 params = request.parameter 2268 2269 import shlex 2270 args = shlex.split(params) 2271 2272 if len(args) == 1 and ',' in args[0]: 2273 # For compatibility with other protoc plugins, support options 2274 # separated by comma. 2275 lex = shlex.shlex(params) 2276 lex.whitespace_split = True 2277 lex.whitespace = ',' 2278 lex.commenters = '' 2279 args = list(lex) 2280 2281 optparser.usage = "Usage: protoc --nanopb_out=[options][,more_options]:outdir file.proto" 2282 optparser.epilog = "Output will be written to file.pb.h and file.pb.c." 2283 2284 if '-h' in args or '--help' in args: 2285 # By default optparser prints help to stdout, which doesn't work for 2286 # protoc plugins. 2287 optparser.print_help(sys.stderr) 2288 sys.exit(1) 2289 2290 options, dummy = optparser.parse_args(args) 2291 2292 if options.version: 2293 sys.stderr.write('%s\n' % (nanopb_version)) 2294 sys.exit(0) 2295 2296 Globals.verbose_options = options.verbose 2297 2298 if options.verbose: 2299 sys.stderr.write("Nanopb version %s\n" % nanopb_version) 2300 sys.stderr.write('Google Python protobuf library imported from %s, version %s\n' 2301 % (google.protobuf.__file__, google.protobuf.__version__)) 2302 2303 response = plugin_pb2.CodeGeneratorResponse() 2304 2305 # Google's protoc does not currently indicate the full path of proto files. 2306 # Instead always add the main file path to the search dirs, that works for 2307 # the common case. 2308 import os.path 2309 options.options_path.append(os.path.dirname(request.file_to_generate[0])) 2310 2311 # Process any include files first, in order to have them 2312 # available as dependencies 2313 other_files = {} 2314 for fdesc in request.proto_file: 2315 other_files[fdesc.name] = parse_file(fdesc.name, fdesc, options) 2316 2317 for filename in request.file_to_generate: 2318 for fdesc in request.proto_file: 2319 if fdesc.name == filename: 2320 results = process_file(filename, fdesc, options, other_files) 2321 2322 f = response.file.add() 2323 f.name = results['headername'] 2324 f.content = results['headerdata'] 2325 2326 f = response.file.add() 2327 f.name = results['sourcename'] 2328 f.content = results['sourcedata'] 2329 2330 if hasattr(plugin_pb2.CodeGeneratorResponse, "FEATURE_PROTO3_OPTIONAL"): 2331 response.supported_features = plugin_pb2.CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL 2332 2333 io.open(sys.stdout.fileno(), "wb").write(response.SerializeToString()) 2334 2335if __name__ == '__main__': 2336 # Check if we are running as a plugin under protoc 2337 if 'protoc-gen-' in sys.argv[0] or '--protoc-plugin' in sys.argv: 2338 main_plugin() 2339 else: 2340 main_cli() 2341