#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

from thrift.Thrift import TException, TType, TFrozenDict
from thrift.transport.TTransport import TTransportException
from ..compat import binary_to_str, str_to_binary

import six
import sys
from itertools import islice
from six.moves import zip


class TProtocolException(TException):
    """Custom Protocol Exception class"""

    # Values stored in ``self.type`` to classify the failure.
    UNKNOWN = 0
    INVALID_DATA = 1
    NEGATIVE_SIZE = 2
    SIZE_LIMIT = 3
    BAD_VERSION = 4
    NOT_IMPLEMENTED = 5
    DEPTH_LIMIT = 6
    INVALID_PROTOCOL = 7

    def __init__(self, type=UNKNOWN, message=None):
        """Store the error category in ``type``; ``message`` goes to TException."""
        TException.__init__(self, message)
        self.type = type


class TProtocolBase(object):
    """Base class for Thrift protocol driver."""

    def __init__(self, trans):
        """Remember the transport; the fast encode/decode hooks start unset."""
        self.trans = trans
        self._fast_decode = None
        self._fast_encode = None

    @staticmethod
    def _check_length(limit, length):
        """Validate a decoded length.

        Raises TTransportException with NEGATIVE_SIZE when ``length`` is below
        zero, or SIZE_LIMIT when ``limit`` is not None and ``length`` exceeds
        it. A ``limit`` of None disables the upper bound.
        """
        if length < 0:
            raise TTransportException(TTransportException.NEGATIVE_SIZE,
                                      'Negative length: %d' % length)
        if limit is not None and length > limit:
            raise TTransportException(TTransportException.SIZE_LIMIT,
                                      'Length exceeded max allowed: %d' % limit)

    # ------------------------------------------------------------------
    # Write primitives. Apart from writeString/writeUtf8 these are no-ops
    # in this base class (presumably overridden by concrete protocols).
    # ------------------------------------------------------------------

    def writeMessageBegin(self, name, ttype, seqid):
        pass

    def writeMessageEnd(self):
        pass

    def writeStructBegin(self, name):
        pass

    def writeStructEnd(self):
        pass

    def writeFieldBegin(self, name, ttype, fid):
        pass

    def writeFieldEnd(self):
        pass

    def writeFieldStop(self):
        pass

    def writeMapBegin(self, ktype, vtype, size):
        pass

    def writeMapEnd(self):
        pass

    def writeListBegin(self, etype, size):
        pass

    def writeListEnd(self):
        pass

    def writeSetBegin(self, etype, size):
        pass

    def writeSetEnd(self):
        pass

    def writeBool(self, bool_val):
        pass

    def writeByte(self, byte):
        pass

    def writeI16(self, i16):
        pass

    def writeI32(self, i32):
        pass

    def writeI64(self, i64):
        pass

    def writeDouble(self, dub):
        pass

    def writeString(self, str_val):
        """Convert ``str_val`` with str_to_binary and emit it via writeBinary."""
        self.writeBinary(str_to_binary(str_val))

    def writeBinary(self, str_val):
        pass

    def writeUtf8(self, str_val):
        """Encode ``str_val`` as UTF-8 and hand the result to writeString.

        Only selected by _ttype_handlers when running on Python 2.
        """
        self.writeString(str_val.encode('utf8'))

    # ------------------------------------------------------------------
    # Read primitives. Apart from readString/readUtf8 these are no-ops
    # in this base class (presumably overridden by concrete protocols).
    # ------------------------------------------------------------------

    def readMessageBegin(self):
        pass

    def readMessageEnd(self):
        pass

    def readStructBegin(self):
        pass

    def readStructEnd(self):
        pass

    def readFieldBegin(self):
        pass

    def readFieldEnd(self):
        pass

    def readMapBegin(self):
        pass

    def readMapEnd(self):
        pass

    def readListBegin(self):
        pass

    def readListEnd(self):
        pass

    def readSetBegin(self):
        pass

    def readSetEnd(self):
        pass

    def readBool(self):
        pass

    def readByte(self):
        pass

    def readI16(self):
        pass

    def readI32(self):
        pass

    def readI64(self):
        pass

    def readDouble(self):
        pass

    def readString(self):
        """Return readBinary()'s result converted with binary_to_str."""
        return binary_to_str(self.readBinary())

    def readBinary(self):
        pass

    def readUtf8(self):
        """Return readString()'s result decoded as UTF-8.

        Only selected by _ttype_handlers when running on Python 2.
        """
        return self.readString().decode('utf8')

    def skip(self, ttype):
        """Read and discard one value of wire type ``ttype``.

        Containers and structs are skipped recursively, element by element.

        Raises TProtocolException (INVALID_DATA) for an unrecognised type.
        """
        if ttype == TType.BOOL:
            self.readBool()
        elif ttype == TType.BYTE:
            self.readByte()
        elif ttype == TType.I16:
            self.readI16()
        elif ttype == TType.I32:
            self.readI32()
        elif ttype == TType.I64:
            self.readI64()
        elif ttype == TType.DOUBLE:
            self.readDouble()
        elif ttype == TType.STRING:
            self.readString()
        elif ttype == TType.STRUCT:
            self.readStructBegin()
            while True:
                # Separate local names: do not rebind the ``ttype`` parameter
                # or shadow the ``id`` builtin.
                (_fname, field_ttype, _fid) = self.readFieldBegin()
                if field_ttype == TType.STOP:
                    break
                self.skip(field_ttype)
                self.readFieldEnd()
            self.readStructEnd()
        elif ttype == TType.MAP:
            (ktype, vtype, size) = self.readMapBegin()
            for _ in range(size):
                self.skip(ktype)
                self.skip(vtype)
            self.readMapEnd()
        elif ttype == TType.SET:
            (etype, size) = self.readSetBegin()
            for _ in range(size):
                self.skip(etype)
            self.readSetEnd()
        elif ttype == TType.LIST:
            (etype, size) = self.readListBegin()
            for _ in range(size):
                self.skip(etype)
            self.readListEnd()
        else:
            raise TProtocolException(
                TProtocolException.INVALID_DATA,
                "invalid TType")

    # Indexed by TType value.
    # tuple of: ( 'reader method' name, 'writer method' name, is_container bool )
    _TTYPE_HANDLERS = (
        (None, None, False),  # 0 TType.STOP
        (None, None, False),  # 1 TType.VOID # TODO: handle void?
        ('readBool', 'writeBool', False),  # 2 TType.BOOL
        ('readByte', 'writeByte', False),  # 3 TType.BYTE and I08
        ('readDouble', 'writeDouble', False),  # 4 TType.DOUBLE
        (None, None, False),  # 5 undefined
        ('readI16', 'writeI16', False),  # 6 TType.I16
        (None, None, False),  # 7 undefined
        ('readI32', 'writeI32', False),  # 8 TType.I32
        (None, None, False),  # 9 undefined
        ('readI64', 'writeI64', False),  # 10 TType.I64
        ('readString', 'writeString', False),  # 11 TType.STRING and UTF7
        ('readContainerStruct', 'writeContainerStruct', True),  # 12 *.STRUCT
        ('readContainerMap', 'writeContainerMap', True),  # 13 TType.MAP
        ('readContainerSet', 'writeContainerSet', True),  # 14 TType.SET
        ('readContainerList', 'writeContainerList', True),  # 15 TType.LIST
        (None, None, False),  # 16 TType.UTF8 # TODO: handle utf8 types?
        (None, None, False)  # 17 TType.UTF16 # TODO: handle utf16 types?
    )

    def _ttype_handlers(self, ttype, spec):
        """Return ``(reader_name, writer_name, is_container)`` for ``ttype``.

        A ``spec`` of 'BINARY' (and, on Python 2 only, 'UTF8') selects the
        dedicated binary/utf8 accessors and requires ``ttype`` to be
        TType.STRING, raising TProtocolException (INVALID_DATA) otherwise.
        Any ``ttype`` outside the handler table yields
        ``(None, None, False)``.
        """
        if spec == 'BINARY':
            if ttype != TType.STRING:
                raise TProtocolException(type=TProtocolException.INVALID_DATA,
                                         message='Invalid binary field type %d' % ttype)
            return ('readBinary', 'writeBinary', False)
        if sys.version_info[0] == 2 and spec == 'UTF8':
            if ttype != TType.STRING:
                raise TProtocolException(type=TProtocolException.INVALID_DATA,
                                         message='Invalid string field type %d' % ttype)
            return ('readUtf8', 'writeUtf8', False)
        # Check the lower bound too: a negative ttype must not wrap around
        # and pick a handler from the end of the table.
        if 0 <= ttype < len(self._TTYPE_HANDLERS):
            return self._TTYPE_HANDLERS[ttype]
        return (None, None, False)

    def _read_by_ttype(self, ttype, spec, espec):
        """Generator yielding an endless stream of decoded ``ttype`` values.

        ``espec`` is the element spec passed to container readers; ``spec``
        is accepted but not used here. The handler lookup (and the
        TProtocolException for an unsupported type) happens on the first
        ``next()``, because this is a generator. Callers bound the stream
        themselves (``next`` or ``islice``).
        """
        reader_name, _, is_container = self._ttype_handlers(ttype, espec)
        if reader_name is None:
            raise TProtocolException(type=TProtocolException.INVALID_DATA,
                                     message='Invalid type %d' % (ttype))
        reader_func = getattr(self, reader_name)
        read = (lambda: reader_func(espec)) if is_container else reader_func
        while True:
            yield read()

    def readFieldByTType(self, ttype, spec):
        """Decode and return a single value of type ``ttype``."""
        return next(self._read_by_ttype(ttype, spec, spec))

    def readContainerList(self, spec):
        """Decode a list; ``spec`` is ``(etype, espec, is_immutable)``.

        Returns a tuple when ``is_immutable`` is true, otherwise a list.
        """
        ttype, tspec, is_immutable = spec
        (list_type, list_len) = self.readListBegin()
        # TODO: compare types we just decoded with thrift_spec
        elems = islice(self._read_by_ttype(ttype, spec, tspec), list_len)
        results = (tuple if is_immutable else list)(elems)
        self.readListEnd()
        return results

    def readContainerSet(self, spec):
        """Decode a set; ``spec`` is ``(etype, espec, is_immutable)``.

        Returns a frozenset when ``is_immutable`` is true, otherwise a set.
        """
        ttype, tspec, is_immutable = spec
        (set_type, set_len) = self.readSetBegin()
        # TODO: compare types we just decoded with thrift_spec
        elems = islice(self._read_by_ttype(ttype, spec, tspec), set_len)
        results = (frozenset if is_immutable else set)(elems)
        self.readSetEnd()
        return results

    def readContainerStruct(self, spec):
        """Decode a struct; ``spec`` is ``(obj_class, obj_spec)``.

        Returns the object produced by ``obj_class.read``.
        """
        (obj_class, obj_spec) = spec

        # If obj_class.read is a classmethod (e.g. in frozen structs),
        # call it as such.
        if getattr(obj_class.read, '__self__', None) is obj_class:
            obj = obj_class.read(self)
        else:
            obj = obj_class()
            obj.read(self)
        return obj

    def readContainerMap(self, spec):
        """Decode a map; ``spec`` is ``(ktype, kspec, vtype, vspec, is_immutable)``.

        Returns a TFrozenDict when ``is_immutable`` is true, otherwise a dict.
        """
        ktype, kspec, vtype, vspec, is_immutable = spec
        (map_ktype, map_vtype, map_len) = self.readMapBegin()
        # TODO: compare types we just decoded with thrift_spec and
        # abort/skip if types disagree
        keys = self._read_by_ttype(ktype, spec, kspec)
        vals = self._read_by_ttype(vtype, spec, vspec)
        # zip pulls one key, then one value, per entry; islice stops the
        # otherwise endless generators after map_len pairs.
        keyvals = islice(zip(keys, vals), map_len)
        results = (TFrozenDict if is_immutable else dict)(keyvals)
        self.readMapEnd()
        return results

    def readStruct(self, obj, thrift_spec, is_immutable=False):
        """Decode struct fields according to ``thrift_spec``.

        ``thrift_spec`` is indexed by field id; each non-None entry is read
        here as ``(fid, ftype, fname, fspec, ...)``. Fields whose id is
        unknown (negative, out of range, or a None slot) or whose wire type
        differs from the spec are skipped.

        When ``is_immutable`` is false the values are set as attributes on
        ``obj`` and None is returned. When it is true, ``obj`` is called with
        the decoded values as keyword arguments and the result is returned.
        """
        fields = {} if is_immutable else None
        self.readStructBegin()
        while True:
            (fname, ftype, fid) = self.readFieldBegin()
            if ftype == TType.STOP:
                break
            # A negative id would not raise IndexError: Python would index
            # from the end of thrift_spec and silently bind the value to the
            # wrong field. Treat it as an unknown field instead.
            if fid < 0:
                field = None
            else:
                try:
                    field = thrift_spec[fid]
                except IndexError:
                    field = None
            if field is not None and ftype == field[1]:
                fname = field[2]
                fspec = field[3]
                val = self.readFieldByTType(ftype, fspec)
                if is_immutable:
                    fields[fname] = val
                else:
                    setattr(obj, fname, val)
            else:
                self.skip(ftype)
            self.readFieldEnd()
        self.readStructEnd()
        if is_immutable:
            return obj(**fields)

    def writeContainerStruct(self, val, spec):
        """Serialize struct ``val`` by delegating to ``val.write``."""
        val.write(self)

    def writeContainerList(self, val, spec):
        """Serialize list ``val``; ``spec`` is ``(etype, espec, is_immutable)``."""
        ttype, tspec, _ = spec
        self.writeListBegin(ttype, len(val))
        # Drain the generator purely for its side effects.
        for _ in self._write_by_ttype(ttype, val, spec, tspec):
            pass
        self.writeListEnd()

    def writeContainerSet(self, val, spec):
        """Serialize set ``val``; ``spec`` is ``(etype, espec, is_immutable)``."""
        ttype, tspec, _ = spec
        self.writeSetBegin(ttype, len(val))
        # Drain the generator purely for its side effects.
        for _ in self._write_by_ttype(ttype, val, spec, tspec):
            pass
        self.writeSetEnd()

    def writeContainerMap(self, val, spec):
        """Serialize map ``val``; ``spec`` is ``(ktype, kspec, vtype, vspec, _)``."""
        ktype, kspec, vtype, vspec, _ = spec
        self.writeMapBegin(ktype, vtype, len(val))
        # zip advances the key writer, then the value writer, once per entry,
        # so keys and values are emitted interleaved.
        for _ in zip(self._write_by_ttype(ktype, six.iterkeys(val), spec, kspec),
                     self._write_by_ttype(vtype, six.itervalues(val), spec, vspec)):
            pass
        self.writeMapEnd()

    def writeStruct(self, obj, thrift_spec):
        """Serialize ``obj`` field by field according to ``thrift_spec``.

        None slots in ``thrift_spec`` and attributes whose value is None are
        not written.
        """
        self.writeStructBegin(obj.__class__.__name__)
        for field in thrift_spec:
            if field is None:
                continue
            fname = field[2]
            val = getattr(obj, fname)
            if val is None:
                # skip writing out unset fields
                continue
            fid = field[0]
            ftype = field[1]
            fspec = field[3]
            self.writeFieldBegin(fname, ftype, fid)
            self.writeFieldByTType(ftype, val, fspec)
            self.writeFieldEnd()
        self.writeFieldStop()
        self.writeStructEnd()

    def _write_by_ttype(self, ttype, vals, spec, espec):
        """Generator that writes each item of ``vals`` as a ``ttype`` value.

        Yields the writer's return value per item; nothing is written until
        the generator is iterated. ``espec`` is forwarded to container
        writers; ``spec`` is accepted but not used here.

        Raises TProtocolException (INVALID_DATA) for a type with no writer,
        mirroring _read_by_ttype.
        """
        _, writer_name, is_container = self._ttype_handlers(ttype, espec)
        if writer_name is None:
            raise TProtocolException(type=TProtocolException.INVALID_DATA,
                                     message='Invalid type %d' % (ttype))
        writer_func = getattr(self, writer_name)
        write = (lambda v: writer_func(v, espec)) if is_container else writer_func
        for v in vals:
            yield write(v)

    def writeFieldByTType(self, ttype, val, spec):
        """Serialize the single value ``val`` as type ``ttype``."""
        next(self._write_by_ttype(ttype, [val], spec, spec))


def checkIntegerLimits(i, bits):
    """Raise TProtocolException (INVALID_DATA) if ``i`` does not fit in a
    signed integer of ``bits`` bits.

    Only 8, 16, 32 and 64 are checked; any other ``bits`` value passes
    silently.
    """
    if bits == 8 and (i < -128 or i > 127):
        raise TProtocolException(TProtocolException.INVALID_DATA,
                                 "i8 requires -128 <= number <= 127")
    elif bits == 16 and (i < -32768 or i > 32767):
        raise TProtocolException(TProtocolException.INVALID_DATA,
                                 "i16 requires -32768 <= number <= 32767")
    elif bits == 32 and (i < -2147483648 or i > 2147483647):
        raise TProtocolException(TProtocolException.INVALID_DATA,
                                 "i32 requires -2147483648 <= number <= 2147483647")
    elif bits == 64 and (i < -9223372036854775808 or i > 9223372036854775807):
        raise TProtocolException(TProtocolException.INVALID_DATA,
                                 "i64 requires -9223372036854775808 <= number <= 9223372036854775807")


class TProtocolFactory(object):
    """Factory interface for protocol objects."""

    def getProtocol(self, trans):
        """No-op here; presumably overridden to return a protocol for ``trans``."""
        pass