#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

from functools import wraps
from struct import pack, unpack

from .TProtocol import TType, TProtocolBase, TProtocolException, TProtocolFactory, checkIntegerLimits
from ..compat import binary_to_str, str_to_binary

__all__ = ['TCompactProtocol', 'TCompactProtocolFactory']

# States of the protocol's internal state machine.  Every read/write method
# asserts that it is called in a state where that call makes sense.
CLEAR = 0
FIELD_WRITE = 1
VALUE_WRITE = 2
CONTAINER_WRITE = 3
BOOL_WRITE = 4
FIELD_READ = 5
CONTAINER_READ = 6
VALUE_READ = 7
BOOL_READ = 8


def make_helper(v_from, container):
    """Build a decorator that restricts a method to two protocol states.

    The decorated method asserts that ``self.state`` is either ``v_from``
    or ``container`` before delegating to the wrapped function.
    """
    def helper(func):
        # wraps() keeps the wrapped function's name and docstring visible
        # on the public alias (useful in tracebacks and introspection).
        @wraps(func)
        def nested(self, *args, **kwargs):
            assert self.state in (v_from, container), (self.state, v_from, container)
            return func(self, *args, **kwargs)
        return nested
    return helper


# Decorators for plain value writers / readers.
writer = make_helper(VALUE_WRITE, CONTAINER_WRITE)
reader = make_helper(VALUE_READ, CONTAINER_READ)


def makeZigZag(n, bits):
    """Return the zig-zag encoding of the signed integer ``n``.

    ``n`` is first passed to ``checkIntegerLimits(n, bits)`` (imported from
    TProtocol; presumably it rejects values that do not fit in ``bits``
    bits -- verify against that helper).
    """
    checkIntegerLimits(n, bits)
    return (n << 1) ^ (n >> (bits - 1))


def fromZigZag(n):
    """Decode a zig-zag encoded non-negative integer back to a signed one."""
    return (n >> 1) ^ -(n & 1)


def writeVarint(trans, n):
    """Write the non-negative integer ``n`` to ``trans`` as a varint.

    The value is emitted seven bits at a time, least-significant group
    first; every byte except the last has its high bit set.

    Raises:
        AssertionError: if ``n`` is negative.
    """
    # An explicit raise (not ``assert``) so the guard survives ``python -O``:
    # a negative value never reaches zero under ``>>`` and would loop forever.
    if n < 0:
        raise AssertionError("Input to TCompactProtocol writeVarint cannot be negative!")
    out = bytearray()
    while n & ~0x7f:
        out.append((n & 0x7f) | 0x80)
        n >>= 7
    out.append(n)
    trans.write(bytes(out))


def readVarint(trans):
    """Read a varint from ``trans`` one byte at a time and return its value.

    Reading stops at the first byte whose high bit is clear.
    """
    result = 0
    shift = 0
    while True:
        x = trans.readAll(1)
        byte = ord(x)
        result |= (byte & 0x7f) << shift
        if byte >> 7 == 0:
            return result
        shift += 7


class CompactType(object):
    """Type codes used on the wire by the compact protocol."""
    STOP = 0x00
    TRUE = 0x01
    FALSE = 0x02
    BYTE = 0x03
    I16 = 0x04
    I32 = 0x05
    I64 = 0x06
    DOUBLE = 0x07
    BINARY = 0x08
    LIST = 0x09
    SET = 0x0A
    MAP = 0x0B
    STRUCT = 0x0C


# Thrift type -> compact wire type.
CTYPES = {
    TType.STOP: CompactType.STOP,
    TType.BOOL: CompactType.TRUE,  # used for collection
    TType.BYTE: CompactType.BYTE,
    TType.I16: CompactType.I16,
    TType.I32: CompactType.I32,
    TType.I64: CompactType.I64,
    TType.DOUBLE: CompactType.DOUBLE,
    TType.STRING: CompactType.BINARY,
    TType.STRUCT: CompactType.STRUCT,
    TType.LIST: CompactType.LIST,
    TType.SET: CompactType.SET,
    TType.MAP: CompactType.MAP,
}

# Compact wire type -> Thrift type (inverse of CTYPES).  FALSE is added by
# hand because CTYPES maps BOOL only to TRUE.
TTYPES = {ctype: ttype for ttype, ctype in CTYPES.items()}
TTYPES[CompactType.FALSE] = TType.BOOL


class TCompactProtocol(TProtocolBase):
    """Compact implementation of the Thrift protocol driver."""

    PROTOCOL_ID = 0x82
    VERSION = 1
    VERSION_MASK = 0x1f
    TYPE_MASK = 0xe0
    TYPE_BITS = 0x07
    TYPE_SHIFT_AMOUNT = 5

    def __init__(self, trans,
                 string_length_limit=None,
                 container_length_limit=None):
        """Create a protocol bound to the transport ``trans``.

        ``string_length_limit`` and ``container_length_limit`` are stored
        and later handed to ``self._check_length`` when sizes are read.
        """
        TProtocolBase.__init__(self, trans)
        self.state = CLEAR
        self.__last_fid = 0
        self.__bool_fid = None
        self.__bool_value = None
        # Saved (state, last field id) pairs for enclosing structs.
        self.__structs = []
        # Saved states for enclosing containers.
        self.__containers = []
        self.string_length_limit = string_length_limit
        self.container_length_limit = container_length_limit

    def _check_string_length(self, length):
        """Check ``length`` against ``string_length_limit``."""
        # _check_length is not defined in this file; presumably inherited
        # from TProtocolBase.
        self._check_length(self.string_length_limit, length)

    def _check_container_length(self, length):
        """Check ``length`` against ``container_length_limit``."""
        self._check_length(self.container_length_limit, length)

    def __writeVarint(self, n):
        writeVarint(self.trans, n)

    def writeMessageBegin(self, name, type, seqid):
        """Write the message header: protocol id, version/type, seqid, name."""
        assert self.state == CLEAR
        self.__writeUByte(self.PROTOCOL_ID)
        self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT))
        # The sequence id is a signed 32-bit integer but the compact protocol
        # writes this out as a "var int" which is always positive, and attempting
        # to write a negative number results in an infinite loop, so we may
        # need to do some conversion here...
        tseqid = seqid
        if tseqid < 0:
            # Map the negative value onto its unsigned 32-bit representation.
            tseqid += 1 << 32
        self.__writeVarint(tseqid)
        self.__writeBinary(str_to_binary(name))
        self.state = VALUE_WRITE

    def writeMessageEnd(self):
        """Finish a message and return the state machine to CLEAR."""
        assert self.state == VALUE_WRITE
        self.state = CLEAR

    def writeStructBegin(self, name):
        """Enter a struct, saving the current state and last field id."""
        assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state
        self.__structs.append((self.state, self.__last_fid))
        self.state = FIELD_WRITE
        self.__last_fid = 0

    def writeStructEnd(self):
        """Leave a struct, restoring the saved state and last field id."""
        assert self.state == FIELD_WRITE
        self.state, self.__last_fid = self.__structs.pop()

    def writeFieldStop(self):
        """Write the zero byte that terminates a struct's field list."""
        self.__writeByte(0)

    def __writeFieldHeader(self, type, fid):
        # Short form: field-id delta (1..15) in the high nibble, type in the
        # low nibble.  Otherwise: type byte followed by the full field id.
        delta = fid - self.__last_fid
        if 0 < delta <= 15:
            self.__writeUByte((delta << 4) | type)
        else:
            self.__writeByte(type)
            self.__writeI16(fid)
        self.__last_fid = fid

    def writeFieldBegin(self, name, type, fid):
        """Begin a field.

        For BOOL fields nothing is written yet: the header is deferred to
        writeBool, which encodes the value in the header's type nibble.
        """
        assert self.state == FIELD_WRITE, self.state
        if type == TType.BOOL:
            self.state = BOOL_WRITE
            self.__bool_fid = fid
        else:
            self.state = VALUE_WRITE
            self.__writeFieldHeader(CTYPES[type], fid)

    def writeFieldEnd(self):
        """End the current field and go back to expecting fields."""
        assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state
        self.state = FIELD_WRITE

    def __writeUByte(self, byte):
        self.trans.write(pack('!B', byte))

    def __writeByte(self, byte):
        self.trans.write(pack('!b', byte))

    def __writeI16(self, i16):
        self.__writeVarint(makeZigZag(i16, 16))

    def __writeSize(self, i32):
        self.__writeVarint(i32)

    def writeCollectionBegin(self, etype, size):
        """Write a list/set header.

        Sizes up to 14 are packed into the high nibble of the header byte;
        larger sizes use 0xf in the high nibble followed by a varint size.
        """
        assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
        if size <= 14:
            self.__writeUByte((size << 4) | CTYPES[etype])
        else:
            self.__writeUByte(0xf0 | CTYPES[etype])
            self.__writeSize(size)
        self.__containers.append(self.state)
        self.state = CONTAINER_WRITE
    writeSetBegin = writeCollectionBegin
    writeListBegin = writeCollectionBegin

    def writeMapBegin(self, ktype, vtype, size):
        """Write a map header.

        An empty map is a single zero byte; otherwise the varint size is
        followed by one byte holding the key type (high nibble) and the
        value type (low nibble).
        """
        assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
        if size == 0:
            self.__writeByte(0)
        else:
            self.__writeSize(size)
            self.__writeUByte((CTYPES[ktype] << 4) | CTYPES[vtype])
        self.__containers.append(self.state)
        self.state = CONTAINER_WRITE

    def writeCollectionEnd(self):
        """Leave a container, restoring the state saved when it began."""
        assert self.state == CONTAINER_WRITE, self.state
        self.state = self.__containers.pop()
    writeMapEnd = writeCollectionEnd
    writeSetEnd = writeCollectionEnd
    writeListEnd = writeCollectionEnd

    def writeBool(self, bool):
        """Write a boolean.

        As a struct field the value is carried by the deferred field
        header's type; inside a container it is written as one byte.

        Raises:
            AssertionError: if called in any other state.
        """
        if self.state == BOOL_WRITE:
            ctype = CompactType.TRUE if bool else CompactType.FALSE
            self.__writeFieldHeader(ctype, self.__bool_fid)
        elif self.state == CONTAINER_WRITE:
            if bool:
                self.__writeByte(CompactType.TRUE)
            else:
                self.__writeByte(CompactType.FALSE)
        else:
            raise AssertionError("Invalid state in compact protocol")

    writeByte = writer(__writeByte)
    writeI16 = writer(__writeI16)

    @writer
    def writeI32(self, i32):
        """Write a 32-bit integer as a zig-zag varint."""
        self.__writeVarint(makeZigZag(i32, 32))

    @writer
    def writeI64(self, i64):
        """Write a 64-bit integer as a zig-zag varint."""
        self.__writeVarint(makeZigZag(i64, 64))

    @writer
    def writeDouble(self, dub):
        """Write a double as 8 bytes, little-endian."""
        self.trans.write(pack('<d', dub))

    def __writeBinary(self, s):
        # Length-prefixed: varint size followed by the raw bytes.
        self.__writeSize(len(s))
        self.trans.write(s)
    writeBinary = writer(__writeBinary)

    def readFieldBegin(self):
        """Read a field header.

        Returns:
            ``(None, ttype, fid)`` for a field, or ``(None, 0, 0)`` when the
            stop marker is read.  For boolean fields the value is taken
            from the header and kept for the following readBool call.
        """
        assert self.state == FIELD_READ, self.state
        type = self.__readUByte()
        if type & 0x0f == TType.STOP:
            return (None, 0, 0)
        delta = type >> 4
        if delta == 0:
            # Long form: the full field id follows the type byte.
            fid = self.__readI16()
        else:
            fid = self.__last_fid + delta
        self.__last_fid = fid
        type = type & 0x0f
        if type == CompactType.TRUE:
            self.state = BOOL_READ
            self.__bool_value = True
        elif type == CompactType.FALSE:
            self.state = BOOL_READ
            self.__bool_value = False
        else:
            self.state = VALUE_READ
        return (None, self.__getTType(type), fid)

    def readFieldEnd(self):
        """End the current field and go back to expecting fields."""
        assert self.state in (VALUE_READ, BOOL_READ), self.state
        self.state = FIELD_READ

    def __readUByte(self):
        result, = unpack('!B', self.trans.readAll(1))
        return result

    def __readByte(self):
        result, = unpack('!b', self.trans.readAll(1))
        return result

    def __readVarint(self):
        return readVarint(self.trans)

    def __readZigZag(self):
        return fromZigZag(self.__readVarint())

    def __readSize(self):
        result = self.__readVarint()
        if result < 0:
            # The first positional argument of TProtocolException is the
            # error type code, so the message must go second.
            raise TProtocolException(TProtocolException.NEGATIVE_SIZE,
                                     "Length < 0")
        return result

    def readMessageBegin(self):
        """Read and validate a message header.

        Returns:
            ``(name, type, seqid)``.

        Raises:
            TProtocolException: (BAD_VERSION) on a wrong protocol id or
                version.
        """
        assert self.state == CLEAR
        proto_id = self.__readUByte()
        if proto_id != self.PROTOCOL_ID:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                                     'Bad protocol id in the message: %d' % proto_id)
        ver_type = self.__readUByte()
        type = (ver_type >> self.TYPE_SHIFT_AMOUNT) & self.TYPE_BITS
        version = ver_type & self.VERSION_MASK
        if version != self.VERSION:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                                     'Bad version: %d (expect %d)' % (version, self.VERSION))
        seqid = self.__readVarint()
        # the sequence is a compact "var int" which is treated as unsigned,
        # however the sequence is actually signed...
        if seqid > 2147483647:
            seqid -= 1 << 32
        name = binary_to_str(self.__readBinary())
        return (name, type, seqid)

    def readMessageEnd(self):
        """Check that the message was fully consumed (no open structs)."""
        assert self.state == CLEAR
        assert len(self.__structs) == 0

    def readStructBegin(self):
        """Enter a struct, saving the current state and last field id."""
        assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state
        self.__structs.append((self.state, self.__last_fid))
        self.state = FIELD_READ
        self.__last_fid = 0

    def readStructEnd(self):
        """Leave a struct, restoring the saved state and last field id."""
        assert self.state == FIELD_READ
        self.state, self.__last_fid = self.__structs.pop()

    def readCollectionBegin(self):
        """Read a list/set header and return ``(element_type, size)``.

        A high nibble of 15 means the real size follows as a varint.
        """
        assert self.state in (VALUE_READ, CONTAINER_READ), self.state
        size_type = self.__readUByte()
        size = size_type >> 4
        type = self.__getTType(size_type)
        if size == 15:
            size = self.__readSize()
        self._check_container_length(size)
        self.__containers.append(self.state)
        self.state = CONTAINER_READ
        return type, size
    readSetBegin = readCollectionBegin
    readListBegin = readCollectionBegin

    def readMapBegin(self):
        """Read a map header and return ``(key_type, value_type, size)``.

        The key/value type byte is only present for non-empty maps; for an
        empty map both types are looked up from a zero type byte.
        """
        assert self.state in (VALUE_READ, CONTAINER_READ), self.state
        size = self.__readSize()
        self._check_container_length(size)
        types = 0
        if size > 0:
            types = self.__readUByte()
        vtype = self.__getTType(types)
        ktype = self.__getTType(types >> 4)
        self.__containers.append(self.state)
        self.state = CONTAINER_READ
        return (ktype, vtype, size)

    def readCollectionEnd(self):
        """Leave a container, restoring the state saved when it began."""
        assert self.state == CONTAINER_READ, self.state
        self.state = self.__containers.pop()
    readSetEnd = readCollectionEnd
    readListEnd = readCollectionEnd
    readMapEnd = readCollectionEnd

    def readBool(self):
        """Read a boolean.

        As a struct field the value was already captured by readFieldBegin;
        inside a container it is read as one byte.

        Raises:
            AssertionError: if called in any other state.
        """
        if self.state == BOOL_READ:
            # readFieldBegin stored a real bool (True/False) here.
            return self.__bool_value
        elif self.state == CONTAINER_READ:
            return self.__readByte() == CompactType.TRUE
        else:
            raise AssertionError("Invalid state in compact protocol: %d" %
                                 self.state)

    readByte = reader(__readByte)
    __readI16 = __readZigZag
    readI16 = reader(__readZigZag)
    readI32 = reader(__readZigZag)
    readI64 = reader(__readZigZag)

    @reader
    def readDouble(self):
        """Read 8 bytes and decode them as a little-endian double."""
        buff = self.trans.readAll(8)
        val, = unpack('<d', buff)
        return val

    def __readBinary(self):
        # Length-prefixed: varint size (checked against the string limit)
        # followed by that many raw bytes.
        size = self.__readSize()
        self._check_string_length(size)
        return self.trans.readAll(size)
    readBinary = reader(__readBinary)

    def __getTType(self, byte):
        # Only the low nibble carries the compact type code.
        return TTYPES[byte & 0x0f]


class TCompactProtocolFactory(TProtocolFactory):
    """Factory producing TCompactProtocol instances with fixed limits."""

    def __init__(self,
                 string_length_limit=None,
                 container_length_limit=None):
        self.string_length_limit = string_length_limit
        self.container_length_limit = container_length_limit

    def getProtocol(self, trans):
        """Return a new TCompactProtocol wrapping ``trans``."""
        return TCompactProtocol(trans,
                                self.string_length_limit,
                                self.container_length_limit)


class TCompactProtocolAccelerated(TCompactProtocol):
    """C-Accelerated version of TCompactProtocol.

    Apart from the constructor, this class does not override any of
    TCompactProtocol's methods, but the generated code recognizes it
    directly and will call into our C module to do the encoding,
    bypassing this object entirely.
    We inherit from TCompactProtocol so that the normal TCompactProtocol
    encoding can happen if the fastbinary module doesn't work for some
    reason.
    To disable this behavior, pass fallback=False constructor argument.

    In order to take advantage of the C module, just use
    TCompactProtocolAccelerated instead of TCompactProtocol.
    """

    def __init__(self, *args, **kwargs):
        """Accept TCompactProtocol's arguments plus ``fallback`` (default True).

        Raises:
            ImportError: if fastbinary cannot be imported and ``fallback``
                is false.
        """
        fallback = kwargs.pop('fallback', True)
        super(TCompactProtocolAccelerated, self).__init__(*args, **kwargs)
        try:
            from thrift.protocol import fastbinary
        except ImportError:
            # With fallback enabled the pure-Python implementation is used.
            if not fallback:
                raise
        else:
            self._fast_decode = fastbinary.decode_compact
            self._fast_encode = fastbinary.encode_compact


class TCompactProtocolAcceleratedFactory(TProtocolFactory):
    """Factory producing TCompactProtocolAccelerated instances."""

    def __init__(self,
                 string_length_limit=None,
                 container_length_limit=None,
                 fallback=True):
        self.string_length_limit = string_length_limit
        self.container_length_limit = container_length_limit
        self._fallback = fallback

    def getProtocol(self, trans):
        """Return a new TCompactProtocolAccelerated wrapping ``trans``."""
        return TCompactProtocolAccelerated(
            trans,
            string_length_limit=self.string_length_limit,
            container_length_limit=self.container_length_limit,
            fallback=self._fallback)