1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10#   http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
20from .TProtocol import TType, TProtocolBase, TProtocolException, TProtocolFactory, checkIntegerLimits
21from struct import pack, unpack
22
23from ..compat import binary_to_str, str_to_binary
24
# Names exported by a star-import of this module.
__all__ = [
    'TCompactProtocol',
    'TCompactProtocolFactory',
]
26
# States of the TCompactProtocol state machine.  ``TCompactProtocol.state``
# holds one of these values and the read/write methods assert on it.
(
    CLEAR,            # initial state; no message body in progress
    FIELD_WRITE,      # inside a struct, ready to write the next field header
    VALUE_WRITE,      # a value may be written (after a message/field header)
    CONTAINER_WRITE,  # writing the elements of a list, set or map
    BOOL_WRITE,       # bool field begun; its header is emitted by writeBool
    FIELD_READ,       # inside a struct, ready to read the next field header
    CONTAINER_READ,   # reading the elements of a list, set or map
    VALUE_READ,       # a field header was read; its value may be read
    BOOL_READ,        # bool field header read; the value came with the header
) = range(9)
36
37
def make_helper(v_from, container):
    """Build a decorator that guards a method with a protocol-state check.

    The returned decorator wraps a method so that, before the method runs,
    ``self.state`` is asserted to be either ``v_from`` or ``container``.
    When the assertion fails, the AssertionError carries the tuple
    ``(self.state, v_from, container)``.  All positional and keyword
    arguments, and the return value, pass through unchanged.
    """
    permitted = (v_from, container)

    def helper(func):
        def nested(self, *args, **kwargs):
            assert self.state in permitted, (self.state, v_from, container)
            result = func(self, *args, **kwargs)
            return result

        return nested

    return helper
45
46
# State guards for the public value accessors: a value may only be written
# (or read) while the protocol expects a field value or a container element.
writer = make_helper(v_from=VALUE_WRITE, container=CONTAINER_WRITE)
reader = make_helper(v_from=VALUE_READ, container=CONTAINER_READ)
49
50
def makeZigZag(n, bits):
    """Zigzag-encode the signed integer ``n`` of width ``bits``.

    ``checkIntegerLimits(n, bits)`` is called first (imported from
    ``.TProtocol``; presumably it rejects values that do not fit in ``bits``
    bits -- confirm there).  The encoding interleaves negative and
    non-negative values so small magnitudes get small codes:
    0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ...
    """
    checkIntegerLimits(n, bits)
    # Arithmetic shift: 0 for non-negative n, -1 (all ones) for negative n.
    sign_fill = n >> (bits - 1)
    doubled = n << 1
    return doubled ^ sign_fill
54
55
def fromZigZag(n):
    """Decode a zigzag-encoded integer back to its signed value.

    Even codes map to non-negative values and odd codes to negative ones:
    0 -> 0, 1 -> -1, 2 -> 1, 3 -> -2, ...
    """
    half = n >> 1
    if n & 1:
        # Odd code: XOR with -1 (all ones) is a bitwise complement.
        return ~half
    return half
58
59
def writeVarint(trans, n):
    """Write the non-negative integer ``n`` to ``trans`` as a varint.

    The value is emitted least-significant group first, seven bits per byte;
    every byte except the last has its high bit (0x80) set.  The encoded
    bytes are handed to ``trans.write`` in a single call.

    Args:
        trans: object with a ``write(bytes)`` method.
        n: non-negative integer to encode.

    Raises:
        AssertionError: if ``n`` is negative.
    """
    if n < 0:
        # Raised explicitly instead of via ``assert`` so the guard survives
        # ``python -O``: a negative value never satisfies the loop's
        # termination test and would otherwise grow the buffer forever.
        raise AssertionError(
            "Input to TCompactProtocol writeVarint cannot be negative!")
    out = bytearray()
    # While bits remain above the low seven, emit a continuation byte.
    while n & ~0x7f:
        out.append((n & 0x7f) | 0x80)
        n >>= 7
    # Final byte: high bit clear marks the end of the varint.
    out.append(n)
    trans.write(bytes(out))
71
72
def readVarint(trans):
    """Read one varint from ``trans`` and return it as a non-negative int.

    Bytes are fetched one at a time with ``trans.readAll(1)``.  The low seven
    bits of each byte are accumulated least-significant group first; a byte
    whose high bit (0x80) is clear ends the value.
    """
    value = 0
    position = 0
    while True:
        octet = ord(trans.readAll(1))
        value |= (octet & 0x7f) << position
        if not octet & 0x80:
            break
        position += 7
    return value
83
84
class CompactType(object):
    """Type codes used on the wire by the compact protocol.

    The codes fit in four bits.  Booleans have two codes, TRUE and FALSE,
    so a boolean's value can be carried by the type code itself.
    """

    STOP = 0
    TRUE = 1
    FALSE = 2
    BYTE = 3
    I16 = 4
    I32 = 5
    I64 = 6
    DOUBLE = 7
    BINARY = 8
    LIST = 9
    SET = 10
    MAP = 11
    STRUCT = 12
99
100
# Thrift TType -> compact wire type code.
CTYPES = {
    TType.STOP: CompactType.STOP,
    # Booleans map to TRUE here; this entry is used for collection headers.
    TType.BOOL: CompactType.TRUE,
    TType.BYTE: CompactType.BYTE,
    TType.I16: CompactType.I16,
    TType.I32: CompactType.I32,
    TType.I64: CompactType.I64,
    TType.DOUBLE: CompactType.DOUBLE,
    TType.STRING: CompactType.BINARY,
    TType.STRUCT: CompactType.STRUCT,
    TType.LIST: CompactType.LIST,
    TType.SET: CompactType.SET,
    TType.MAP: CompactType.MAP,
}

# Compact wire type code -> Thrift TType: the inverse of CTYPES, plus FALSE,
# which CTYPES never produces but which also denotes a boolean.
TTYPES = {ctype: ttype for ttype, ctype in CTYPES.items()}
TTYPES[CompactType.FALSE] = TType.BOOL
122
123
class TCompactProtocol(TProtocolBase):
    """Compact implementation of the Thrift protocol driver.

    Integers are written as zigzag varints, a field header packs the field-id
    delta and the type into one byte when the delta is small, and a boolean
    struct field carries its value in the field header's type code.

    ``self.state`` records where the protocol is inside a message (one of the
    module-level ``CLEAR`` / ``*_WRITE`` / ``*_READ`` constants); most methods
    assert that they are called in a state that permits them.

    Args:
        trans: transport; its ``write`` and ``readAll`` methods are used.
        string_length_limit: limit passed to ``_check_length`` for the size
            of each binary value read (``None`` by default).
        container_length_limit: limit passed to ``_check_length`` for the
            size of each list, set or map read (``None`` by default).
    """

    # First byte of every message.
    PROTOCOL_ID = 0x82
    # Protocol version written into / expected in the second header byte.
    VERSION = 1
    # Low five bits of the second header byte hold the version.
    VERSION_MASK = 0x1f
    # High three bits of the second header byte (not referenced in this class).
    TYPE_MASK = 0xe0
    # Mask applied to the message type after it has been shifted down.
    TYPE_BITS = 0x07
    # The message type sits above the five version bits.
    TYPE_SHIFT_AMOUNT = 5

    def __init__(self, trans,
                 string_length_limit=None,
                 container_length_limit=None):
        TProtocolBase.__init__(self, trans)
        self.state = CLEAR
        # Id of the previous field in the current struct (for delta encoding).
        self.__last_fid = 0
        # Field id of a bool field whose header is deferred until writeBool.
        self.__bool_fid = None
        # Value of a bool field decoded from its header by readFieldBegin.
        self.__bool_value = None
        # Stack of (state, last_fid) saved on entering a nested struct.
        self.__structs = []
        # Stack of states saved on entering a container.
        self.__containers = []
        self.string_length_limit = string_length_limit
        self.container_length_limit = container_length_limit

    def _check_string_length(self, length):
        """Check ``length`` against ``string_length_limit``.

        Delegates to ``self._check_length``, which is not defined in this
        class (presumably inherited from TProtocolBase).
        """
        self._check_length(self.string_length_limit, length)

    def _check_container_length(self, length):
        """Check ``length`` against ``container_length_limit``.

        Delegates to ``self._check_length``, which is not defined in this
        class (presumably inherited from TProtocolBase).
        """
        self._check_length(self.container_length_limit, length)

    def __writeVarint(self, n):
        """Write the non-negative integer ``n`` to the transport as a varint."""
        writeVarint(self.trans, n)

    def writeMessageBegin(self, name, type, seqid):
        """Write the message header and move from CLEAR to VALUE_WRITE.

        The header is the protocol id byte, a byte combining the version
        with the message ``type``, the sequence id as a varint and the
        message ``name`` as a length-prefixed binary.
        """
        assert self.state == CLEAR
        self.__writeUByte(self.PROTOCOL_ID)
        self.__writeUByte(self.VERSION | (type << self.TYPE_SHIFT_AMOUNT))
        # The sequence id is a signed 32-bit integer but the compact protocol
        # writes it out as a varint, which must be non-negative, so a negative
        # id is mapped onto its unsigned 32-bit equivalent (2**32 + seqid).
        tseqid = seqid
        if tseqid < 0:
            tseqid = 2147483648 + (2147483648 + tseqid)
        self.__writeVarint(tseqid)
        self.__writeBinary(str_to_binary(name))
        self.state = VALUE_WRITE

    def writeMessageEnd(self):
        """Finish a message: VALUE_WRITE -> CLEAR.  Writes nothing."""
        assert self.state == VALUE_WRITE
        self.state = CLEAR

    def writeStructBegin(self, name):
        """Enter a struct; ``name`` is not written.

        The current state and last field id are pushed so writeStructEnd can
        restore them, and field-id delta tracking restarts at 0.
        """
        assert self.state in (CLEAR, CONTAINER_WRITE, VALUE_WRITE), self.state
        self.__structs.append((self.state, self.__last_fid))
        self.state = FIELD_WRITE
        self.__last_fid = 0

    def writeStructEnd(self):
        """Leave a struct, restoring the state saved by writeStructBegin."""
        assert self.state == FIELD_WRITE
        self.state, self.__last_fid = self.__structs.pop()

    def writeFieldStop(self):
        """Write the single zero byte that terminates a struct's fields."""
        self.__writeByte(0)

    def __writeFieldHeader(self, type, fid):
        """Write a field header for compact type ``type`` and field id ``fid``.

        When ``fid`` exceeds the previous field id by 1..15 the header is one
        byte, ``delta << 4 | type``.  Otherwise it is the type byte followed
        by the field id as a zigzag varint.
        """
        delta = fid - self.__last_fid
        if 0 < delta <= 15:
            self.__writeUByte(delta << 4 | type)
        else:
            self.__writeByte(type)
            self.__writeI16(fid)
        self.__last_fid = fid

    def writeFieldBegin(self, name, type, fid):
        """Begin a struct field; ``name`` is not written.

        For a bool field nothing is written yet: the header's type code
        depends on the value, so it is emitted later by writeBool.  For any
        other type the field header is written immediately.
        """
        assert self.state == FIELD_WRITE, self.state
        if type == TType.BOOL:
            self.state = BOOL_WRITE
            self.__bool_fid = fid
        else:
            self.state = VALUE_WRITE
            self.__writeFieldHeader(CTYPES[type], fid)

    def writeFieldEnd(self):
        """End a struct field and return to FIELD_WRITE.  Writes nothing."""
        assert self.state in (VALUE_WRITE, BOOL_WRITE), self.state
        self.state = FIELD_WRITE

    def __writeUByte(self, byte):
        """Write one unsigned byte."""
        self.trans.write(pack('!B', byte))

    def __writeByte(self, byte):
        """Write one signed byte."""
        self.trans.write(pack('!b', byte))

    def __writeI16(self, i16):
        """Write a 16-bit signed integer as a zigzag varint."""
        self.__writeVarint(makeZigZag(i16, 16))

    def __writeSize(self, i32):
        """Write a size as a plain (non-zigzag) varint."""
        self.__writeVarint(i32)

    def writeCollectionBegin(self, etype, size):
        """Write a list/set header and enter CONTAINER_WRITE.

        Sizes up to 14 share one byte with the element type
        (``size << 4 | ctype``); larger sizes write ``0xf0 | ctype`` followed
        by the size as a varint.  The previous state is pushed for
        writeCollectionEnd.
        """
        assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
        if size <= 14:
            self.__writeUByte(size << 4 | CTYPES[etype])
        else:
            self.__writeUByte(0xf0 | CTYPES[etype])
            self.__writeSize(size)
        self.__containers.append(self.state)
        self.state = CONTAINER_WRITE
    writeSetBegin = writeCollectionBegin
    writeListBegin = writeCollectionBegin

    def writeMapBegin(self, ktype, vtype, size):
        """Write a map header and enter CONTAINER_WRITE.

        An empty map is a single zero byte.  Otherwise the size is written
        as a varint followed by one byte holding the key type in the high
        nibble and the value type in the low nibble.
        """
        assert self.state in (VALUE_WRITE, CONTAINER_WRITE), self.state
        if size == 0:
            self.__writeByte(0)
        else:
            self.__writeSize(size)
            self.__writeUByte(CTYPES[ktype] << 4 | CTYPES[vtype])
        self.__containers.append(self.state)
        self.state = CONTAINER_WRITE

    def writeCollectionEnd(self):
        """Leave a container, restoring the state saved on entry."""
        assert self.state == CONTAINER_WRITE, self.state
        self.state = self.__containers.pop()
    writeMapEnd = writeCollectionEnd
    writeSetEnd = writeCollectionEnd
    writeListEnd = writeCollectionEnd

    def writeBool(self, bool):
        """Write a boolean.

        As a struct field (BOOL_WRITE) the value is encoded in the deferred
        field header's type code.  As a container element (CONTAINER_WRITE)
        it is written as one byte, ``CompactType.TRUE`` or
        ``CompactType.FALSE``.

        Raises:
            AssertionError: in any other state.
        """
        if self.state == BOOL_WRITE:
            if bool:
                ctype = CompactType.TRUE
            else:
                ctype = CompactType.FALSE
            self.__writeFieldHeader(ctype, self.__bool_fid)
        elif self.state == CONTAINER_WRITE:
            if bool:
                self.__writeByte(CompactType.TRUE)
            else:
                self.__writeByte(CompactType.FALSE)
        else:
            raise AssertionError("Invalid state in compact protocol")

    # Public value writers: the private writers guarded by the state check.
    writeByte = writer(__writeByte)
    writeI16 = writer(__writeI16)

    @writer
    def writeI32(self, i32):
        """Write a 32-bit signed integer as a zigzag varint."""
        self.__writeVarint(makeZigZag(i32, 32))

    @writer
    def writeI64(self, i64):
        """Write a 64-bit signed integer as a zigzag varint."""
        self.__writeVarint(makeZigZag(i64, 64))

    @writer
    def writeDouble(self, dub):
        """Write a double as 8 little-endian bytes."""
        self.trans.write(pack('<d', dub))

    def __writeBinary(self, s):
        """Write ``s`` as a varint length followed by the raw bytes."""
        self.__writeSize(len(s))
        self.trans.write(s)
    writeBinary = writer(__writeBinary)

    def readFieldBegin(self):
        """Read a field header.

        Returns:
            ``(None, ttype, fid)`` for a field, or ``(None, 0, 0)`` when the
            stop marker is read.  The field name is never on the wire, so the
            first element is always ``None``.

        For a boolean field the value is taken from the header's type code
        and kept for the following readBool call.
        """
        assert self.state == FIELD_READ, self.state
        type = self.__readUByte()
        # The low nibble is a compact wire type code, so compare it against
        # the compact STOP code.
        if type & 0x0f == CompactType.STOP:
            return (None, 0, 0)
        delta = type >> 4
        if delta == 0:
            # No delta in the header: the field id follows as a zigzag varint.
            fid = self.__readI16()
        else:
            fid = self.__last_fid + delta
        self.__last_fid = fid
        type = type & 0x0f
        if type == CompactType.TRUE:
            self.state = BOOL_READ
            self.__bool_value = True
        elif type == CompactType.FALSE:
            self.state = BOOL_READ
            self.__bool_value = False
        else:
            self.state = VALUE_READ
        return (None, self.__getTType(type), fid)

    def readFieldEnd(self):
        """End a struct field and return to FIELD_READ.  Reads nothing."""
        assert self.state in (VALUE_READ, BOOL_READ), self.state
        self.state = FIELD_READ

    def __readUByte(self):
        """Read one unsigned byte."""
        result, = unpack('!B', self.trans.readAll(1))
        return result

    def __readByte(self):
        """Read one signed byte."""
        result, = unpack('!b', self.trans.readAll(1))
        return result

    def __readVarint(self):
        """Read one varint from the transport."""
        return readVarint(self.trans)

    def __readZigZag(self):
        """Read a varint and zigzag-decode it to a signed integer."""
        return fromZigZag(self.__readVarint())

    def __readSize(self):
        """Read a size written as a plain varint.

        Raises:
            TProtocolException: with type NEGATIVE_SIZE if the decoded size
                is negative.
        """
        result = self.__readVarint()
        if result < 0:
            # The first argument of TProtocolException is the error type;
            # the message goes second, as in readMessageBegin.
            raise TProtocolException(TProtocolException.NEGATIVE_SIZE,
                                     "Length < 0")
        return result

    def readMessageBegin(self):
        """Read and validate a message header.

        Returns:
            ``(name, type, seqid)``.

        Raises:
            TProtocolException: with type BAD_VERSION if the protocol id or
                the version does not match.

        ``self.state`` is left at CLEAR.
        """
        assert self.state == CLEAR
        proto_id = self.__readUByte()
        if proto_id != self.PROTOCOL_ID:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                                     'Bad protocol id in the message: %d' % proto_id)
        ver_type = self.__readUByte()
        type = (ver_type >> self.TYPE_SHIFT_AMOUNT) & self.TYPE_BITS
        version = ver_type & self.VERSION_MASK
        if version != self.VERSION:
            raise TProtocolException(TProtocolException.BAD_VERSION,
                                     'Bad version: %d (expect %d)' % (version, self.VERSION))
        seqid = self.__readVarint()
        # The sequence id is a varint, which is treated as unsigned, but the
        # id is actually a signed 32-bit value: map it back (seqid - 2**32).
        if seqid > 2147483647:
            seqid = -2147483648 - (2147483648 - seqid)
        name = binary_to_str(self.__readBinary())
        return (name, type, seqid)

    def readMessageEnd(self):
        """Check that the message was fully consumed.  Reads nothing."""
        assert self.state == CLEAR
        assert len(self.__structs) == 0

    def readStructBegin(self):
        """Enter a struct.

        The current state and last field id are pushed so readStructEnd can
        restore them, and field-id delta tracking restarts at 0.
        """
        assert self.state in (CLEAR, CONTAINER_READ, VALUE_READ), self.state
        self.__structs.append((self.state, self.__last_fid))
        self.state = FIELD_READ
        self.__last_fid = 0

    def readStructEnd(self):
        """Leave a struct, restoring the state saved by readStructBegin."""
        assert self.state == FIELD_READ
        self.state, self.__last_fid = self.__structs.pop()

    def readCollectionBegin(self):
        """Read a list/set header and enter CONTAINER_READ.

        Returns:
            ``(element_ttype, size)``.

        A high nibble of 15 means the real size follows as a varint.  The
        size is checked with ``_check_container_length``.
        """
        assert self.state in (VALUE_READ, CONTAINER_READ), self.state
        size_type = self.__readUByte()
        size = size_type >> 4
        type = self.__getTType(size_type)
        if size == 15:
            size = self.__readSize()
        self._check_container_length(size)
        self.__containers.append(self.state)
        self.state = CONTAINER_READ
        return type, size
    readSetBegin = readCollectionBegin
    readListBegin = readCollectionBegin

    def readMapBegin(self):
        """Read a map header and enter CONTAINER_READ.

        Returns:
            ``(key_ttype, value_ttype, size)``.

        The key/value type byte is present only for a non-empty map; for an
        empty map both types are looked up from type code 0.
        """
        assert self.state in (VALUE_READ, CONTAINER_READ), self.state
        size = self.__readSize()
        self._check_container_length(size)
        types = 0
        if size > 0:
            types = self.__readUByte()
        vtype = self.__getTType(types)
        ktype = self.__getTType(types >> 4)
        self.__containers.append(self.state)
        self.state = CONTAINER_READ
        return (ktype, vtype, size)

    def readCollectionEnd(self):
        """Leave a container, restoring the state saved on entry."""
        assert self.state == CONTAINER_READ, self.state
        self.state = self.__containers.pop()
    readSetEnd = readCollectionEnd
    readListEnd = readCollectionEnd
    readMapEnd = readCollectionEnd

    def readBool(self):
        """Read a boolean.

        As a struct field (BOOL_READ) the value was already decoded from the
        field header by readFieldBegin.  As a container element
        (CONTAINER_READ) one byte is read and compared with
        ``CompactType.TRUE``.

        Raises:
            AssertionError: in any other state.
        """
        if self.state == BOOL_READ:
            # Stored as a Python bool by readFieldBegin; return it directly
            # rather than comparing it with a wire type code.
            return self.__bool_value
        elif self.state == CONTAINER_READ:
            return self.__readByte() == CompactType.TRUE
        else:
            raise AssertionError("Invalid state in compact protocol: %d" %
                                 self.state)

    # Public value readers: the private readers guarded by the state check.
    # All integer widths share the same zigzag-varint decoding.
    readByte = reader(__readByte)
    __readI16 = __readZigZag
    readI16 = reader(__readZigZag)
    readI32 = reader(__readZigZag)
    readI64 = reader(__readZigZag)

    @reader
    def readDouble(self):
        """Read a double stored as 8 little-endian bytes."""
        buff = self.trans.readAll(8)
        val, = unpack('<d', buff)
        return val

    def __readBinary(self):
        """Read a varint length, check it, then read that many raw bytes."""
        size = self.__readSize()
        self._check_string_length(size)
        return self.trans.readAll(size)
    readBinary = reader(__readBinary)

    def __getTType(self, byte):
        """Map the low nibble of ``byte`` (a compact type code) to a TType.

        Raises:
            KeyError: if the nibble is not a known compact type code.
        """
        return TTYPES[byte & 0x0f]
429
430
class TCompactProtocolFactory(TProtocolFactory):
    """Factory producing TCompactProtocol instances with preset limits."""

    def __init__(self,
                 string_length_limit=None,
                 container_length_limit=None):
        # Remembered here and handed to every protocol created later.
        self.container_length_limit = container_length_limit
        self.string_length_limit = string_length_limit

    def getProtocol(self, trans):
        """Return a new TCompactProtocol bound to the transport ``trans``."""
        protocol = TCompactProtocol(
            trans,
            string_length_limit=self.string_length_limit,
            container_length_limit=self.container_length_limit)
        return protocol
442
443
class TCompactProtocolAccelerated(TCompactProtocol):
    """C-accelerated variant of TCompactProtocol.

    No TCompactProtocol method is overridden.  Generated code is expected to
    recognize this class and call into the C module for encoding and
    decoding, bypassing this object.  Because the class inherits from
    TCompactProtocol, the pure-Python encoding remains available when the
    ``fastbinary`` module cannot be imported.

    That silent fallback is the default; pass ``fallback=False`` to the
    constructor to have the ImportError raised instead.

    To use the C module, construct TCompactProtocolAccelerated wherever
    TCompactProtocol would otherwise be used.
    """

    def __init__(self, *args, **kwargs):
        # 'fallback' is consumed here; everything else goes to the base class.
        allow_fallback = kwargs.pop('fallback', True)
        super(TCompactProtocolAccelerated, self).__init__(*args, **kwargs)
        try:
            from thrift.protocol import fastbinary as accel
        except ImportError:
            if allow_fallback:
                # Continue without the C hooks.
                return
            raise
        self._fast_decode = accel.decode_compact
        self._fast_encode = accel.encode_compact
471
472
class TCompactProtocolAcceleratedFactory(TProtocolFactory):
    """Factory producing TCompactProtocolAccelerated instances."""

    def __init__(self,
                 string_length_limit=None,
                 container_length_limit=None,
                 fallback=True):
        # Remembered here and handed to every protocol created later.
        self._fallback = fallback
        self.container_length_limit = container_length_limit
        self.string_length_limit = string_length_limit

    def getProtocol(self, trans):
        """Return a new TCompactProtocolAccelerated bound to ``trans``."""
        options = {
            'string_length_limit': self.string_length_limit,
            'container_length_limit': self.container_length_limit,
            'fallback': self._fallback,
        }
        return TCompactProtocolAccelerated(trans, **options)
488