1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10#   http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
20from thrift.Thrift import TException, TType, TFrozenDict
21from thrift.transport.TTransport import TTransportException
22from ..compat import binary_to_str, str_to_binary
23
24import six
25import sys
26from itertools import islice
27from six.moves import zip
28
29
class TProtocolException(TException):
    """Exception raised for protocol-level failures.

    Each instance carries a numeric ``type`` attribute holding one of the
    category codes declared on this class, in addition to the message that
    is forwarded to ``TException``.
    """

    # Category codes stored on ``self.type``.
    UNKNOWN = 0
    INVALID_DATA = 1
    NEGATIVE_SIZE = 2
    SIZE_LIMIT = 3
    BAD_VERSION = 4
    NOT_IMPLEMENTED = 5
    DEPTH_LIMIT = 6
    INVALID_PROTOCOL = 7

    def __init__(self, type=UNKNOWN, message=None):
        """Create the exception.

        :param type: one of the category codes above (defaults to UNKNOWN).
        :param message: optional human-readable description, passed on to
            the parent class constructor.
        """
        super(TProtocolException, self).__init__(message)
        self.type = type
45
46
class TProtocolBase(object):
    """Base class for Thrift protocol driver.

    The primitive ``write*`` / ``read*`` methods are no-ops here; concrete
    protocols are expected to override them.  On top of those primitives this
    class implements generic, spec-driven (de)serialization of structs and
    containers (``readStruct``, ``writeStruct``, ``readContainer*``,
    ``writeContainer*``) and ``skip`` for discarding values of a given type.
    """

    def __init__(self, trans):
        """Remember the transport ``trans`` this protocol operates on."""
        self.trans = trans
        # NOTE(review): presumably hooks for an accelerated codec, assigned by
        # subclasses -- nothing in this class reads them; confirm with callers.
        self._fast_decode = None
        self._fast_encode = None

    @staticmethod
    def _check_length(limit, length):
        """Validate a decoded length.

        :param limit: maximum allowed length, or None for "no limit".
        :param length: the length to validate.
        :raises TTransportException: NEGATIVE_SIZE if ``length`` is negative,
            SIZE_LIMIT if ``limit`` is set and ``length`` exceeds it.
        """
        if length < 0:
            raise TTransportException(TTransportException.NEGATIVE_SIZE,
                                      'Negative length: %d' % length)
        if limit is not None and length > limit:
            raise TTransportException(TTransportException.SIZE_LIMIT,
                                      'Length exceeded max allowed: %d' % limit)

    # ------------------------------------------------------------------
    # Write primitives.  No-ops in the base class (except where noted).
    # ------------------------------------------------------------------

    def writeMessageBegin(self, name, ttype, seqid):
        pass

    def writeMessageEnd(self):
        pass

    def writeStructBegin(self, name):
        pass

    def writeStructEnd(self):
        pass

    def writeFieldBegin(self, name, ttype, fid):
        pass

    def writeFieldEnd(self):
        pass

    def writeFieldStop(self):
        pass

    def writeMapBegin(self, ktype, vtype, size):
        pass

    def writeMapEnd(self):
        pass

    def writeListBegin(self, etype, size):
        pass

    def writeListEnd(self):
        pass

    def writeSetBegin(self, etype, size):
        pass

    def writeSetEnd(self):
        pass

    def writeBool(self, bool_val):
        pass

    def writeByte(self, byte):
        pass

    def writeI16(self, i16):
        pass

    def writeI32(self, i32):
        pass

    def writeI64(self, i64):
        pass

    def writeDouble(self, dub):
        pass

    def writeString(self, str_val):
        """Write a text string by converting it with ``str_to_binary`` and
        delegating to ``writeBinary``."""
        self.writeBinary(str_to_binary(str_val))

    def writeBinary(self, str_val):
        pass

    def writeUtf8(self, str_val):
        """Encode ``str_val`` as UTF-8 and write it via ``writeString``.

        Only selected by ``_ttype_handlers`` when running on Python 2.
        """
        self.writeString(str_val.encode('utf8'))

    # ------------------------------------------------------------------
    # Read primitives.  No-ops in the base class (except where noted).
    # ------------------------------------------------------------------

    def readMessageBegin(self):
        pass

    def readMessageEnd(self):
        pass

    def readStructBegin(self):
        pass

    def readStructEnd(self):
        pass

    def readFieldBegin(self):
        pass

    def readFieldEnd(self):
        pass

    def readMapBegin(self):
        pass

    def readMapEnd(self):
        pass

    def readListBegin(self):
        pass

    def readListEnd(self):
        pass

    def readSetBegin(self):
        pass

    def readSetEnd(self):
        pass

    def readBool(self):
        pass

    def readByte(self):
        pass

    def readI16(self):
        pass

    def readI32(self):
        pass

    def readI64(self):
        pass

    def readDouble(self):
        pass

    def readString(self):
        """Read a binary value via ``readBinary`` and convert it with
        ``binary_to_str``."""
        return binary_to_str(self.readBinary())

    def readBinary(self):
        pass

    def readUtf8(self):
        """Read a string via ``readString`` and decode it as UTF-8.

        Only selected by ``_ttype_handlers`` when running on Python 2.
        """
        return self.readString().decode('utf8')

    def skip(self, ttype):
        """Read and discard one value of wire type ``ttype``.

        Structs and containers are skipped recursively, element by element.

        :raises TProtocolException: INVALID_DATA if ``ttype`` is not one of
            the types handled below.
        """
        if ttype == TType.BOOL:
            self.readBool()
        elif ttype == TType.BYTE:
            self.readByte()
        elif ttype == TType.I16:
            self.readI16()
        elif ttype == TType.I32:
            self.readI32()
        elif ttype == TType.I64:
            self.readI64()
        elif ttype == TType.DOUBLE:
            self.readDouble()
        elif ttype == TType.STRING:
            self.readString()
        elif ttype == TType.STRUCT:
            self.readStructBegin()
            while True:
                # Only the field type matters when skipping; name and id are
                # read from the wire but ignored.
                (_fname, ftype, _fid) = self.readFieldBegin()
                if ftype == TType.STOP:
                    break
                self.skip(ftype)
                self.readFieldEnd()
            self.readStructEnd()
        elif ttype == TType.MAP:
            (ktype, vtype, size) = self.readMapBegin()
            for _ in range(size):
                self.skip(ktype)
                self.skip(vtype)
            self.readMapEnd()
        elif ttype == TType.SET:
            (etype, size) = self.readSetBegin()
            for _ in range(size):
                self.skip(etype)
            self.readSetEnd()
        elif ttype == TType.LIST:
            (etype, size) = self.readListBegin()
            for _ in range(size):
                self.skip(etype)
            self.readListEnd()
        else:
            raise TProtocolException(
                TProtocolException.INVALID_DATA,
                "invalid TType")

    # Indexed by TType value.
    # tuple of: ( 'reader method' name, 'writer method' name, is_container bool )
    _TTYPE_HANDLERS = (
        (None, None, False),  # 0 TType.STOP
        (None, None, False),  # 1 TType.VOID # TODO: handle void?
        ('readBool', 'writeBool', False),  # 2 TType.BOOL
        ('readByte', 'writeByte', False),  # 3 TType.BYTE and I08
        ('readDouble', 'writeDouble', False),  # 4 TType.DOUBLE
        (None, None, False),  # 5 undefined
        ('readI16', 'writeI16', False),  # 6 TType.I16
        (None, None, False),  # 7 undefined
        ('readI32', 'writeI32', False),  # 8 TType.I32
        (None, None, False),  # 9 undefined
        ('readI64', 'writeI64', False),  # 10 TType.I64
        ('readString', 'writeString', False),  # 11 TType.STRING and UTF7
        ('readContainerStruct', 'writeContainerStruct', True),  # 12 *.STRUCT
        ('readContainerMap', 'writeContainerMap', True),  # 13 TType.MAP
        ('readContainerSet', 'writeContainerSet', True),  # 14 TType.SET
        ('readContainerList', 'writeContainerList', True),  # 15 TType.LIST
        (None, None, False),  # 16 TType.UTF8 # TODO: handle utf8 types?
        (None, None, False)  # 17 TType.UTF16 # TODO: handle utf16 types?
    )

    def _ttype_handlers(self, ttype, spec):
        """Return ``(reader_name, writer_name, is_container)`` for ``ttype``.

        ``spec == 'BINARY'`` selects the binary accessors, and on Python 2
        ``spec == 'UTF8'`` selects the UTF-8 accessors; both require
        ``ttype`` to be TType.STRING.  Otherwise the entry is looked up in
        ``_TTYPE_HANDLERS``; a ``ttype`` outside the table yields
        ``(None, None, False)``.

        :raises TProtocolException: INVALID_DATA if a 'BINARY' / 'UTF8' spec
            is combined with a non-STRING ``ttype``.
        """
        if spec == 'BINARY':
            if ttype != TType.STRING:
                raise TProtocolException(type=TProtocolException.INVALID_DATA,
                                         message='Invalid binary field type %d' % ttype)
            return ('readBinary', 'writeBinary', False)
        if sys.version_info[0] == 2 and spec == 'UTF8':
            if ttype != TType.STRING:
                raise TProtocolException(type=TProtocolException.INVALID_DATA,
                                         message='Invalid string field type %d' % ttype)
            return ('readUtf8', 'writeUtf8', False)
        # The lower bound matters: a negative ttype would otherwise index the
        # table from its end and silently pick an unrelated handler.
        if 0 <= ttype < len(self._TTYPE_HANDLERS):
            return self._TTYPE_HANDLERS[ttype]
        return (None, None, False)

    def _read_by_ttype(self, ttype, spec, espec):
        """Endless generator yielding successive values of type ``ttype``.

        Each ``next()`` performs exactly one read, so callers bound the
        number of reads themselves (``next`` / ``islice``).  ``espec`` is the
        element spec passed to container readers; ``spec`` is accepted for
        signature symmetry and is not used.

        :raises TProtocolException: INVALID_DATA (on first iteration) if no
            reader exists for ``ttype``.
        """
        reader_name, _, is_container = self._ttype_handlers(ttype, espec)
        if reader_name is None:
            raise TProtocolException(type=TProtocolException.INVALID_DATA,
                                     message='Invalid type %d' % (ttype))
        reader_func = getattr(self, reader_name)
        read = (lambda: reader_func(espec)) if is_container else reader_func
        while True:
            yield read()

    def readFieldByTType(self, ttype, spec):
        """Read and return a single value of type ``ttype`` described by ``spec``."""
        return next(self._read_by_ttype(ttype, spec, spec))

    def readContainerList(self, spec):
        """Read a list described by ``spec = (etype, espec, is_immutable)``.

        Returns a tuple when ``is_immutable`` is true, otherwise a list.
        """
        ttype, tspec, is_immutable = spec
        (list_type, list_len) = self.readListBegin()
        # TODO: compare types we just decoded with thrift_spec
        elems = islice(self._read_by_ttype(ttype, spec, tspec), list_len)
        results = (tuple if is_immutable else list)(elems)
        self.readListEnd()
        return results

    def readContainerSet(self, spec):
        """Read a set described by ``spec = (etype, espec, is_immutable)``.

        Returns a frozenset when ``is_immutable`` is true, otherwise a set.
        """
        ttype, tspec, is_immutable = spec
        (set_type, set_len) = self.readSetBegin()
        # TODO: compare types we just decoded with thrift_spec
        elems = islice(self._read_by_ttype(ttype, spec, tspec), set_len)
        results = (frozenset if is_immutable else set)(elems)
        self.readSetEnd()
        return results

    def readContainerStruct(self, spec):
        """Read a struct described by ``spec = (obj_class, obj_spec)`` and
        return the resulting object."""
        (obj_class, obj_spec) = spec

        # If obj_class.read is a classmethod (e.g. in frozen structs),
        # call it as such.
        if getattr(obj_class.read, '__self__', None) is obj_class:
            obj = obj_class.read(self)
        else:
            obj = obj_class()
            obj.read(self)
        return obj

    def readContainerMap(self, spec):
        """Read a map described by
        ``spec = (ktype, kspec, vtype, vspec, is_immutable)``.

        Returns a TFrozenDict when ``is_immutable`` is true, otherwise a dict.
        """
        ktype, kspec, vtype, vspec, is_immutable = spec
        (map_ktype, map_vtype, map_len) = self.readMapBegin()
        # TODO: compare types we just decoded with thrift_spec and
        # abort/skip if types disagree
        keys = self._read_by_ttype(ktype, spec, kspec)
        vals = self._read_by_ttype(vtype, spec, vspec)
        # zip pulls one key then one value per pair, which interleaves the
        # reads in wire order; islice stops after exactly map_len pairs.
        keyvals = islice(zip(keys, vals), map_len)
        results = (TFrozenDict if is_immutable else dict)(keyvals)
        self.readMapEnd()
        return results

    def readStruct(self, obj, thrift_spec, is_immutable=False):
        """Read a struct from the wire according to ``thrift_spec``.

        ``thrift_spec`` is indexed by field id; each entry is either None or
        a sequence whose items 1, 2 and 3 are the field's type, name and
        nested spec.  Fields that are unknown, or whose wire type differs
        from the spec, are skipped.

        :param obj: when ``is_immutable`` is false, the instance whose
            attributes are set; when true, a callable invoked with the
            decoded fields as keyword arguments.
        :returns: ``obj(**fields)`` when ``is_immutable`` is true, else None.
        """
        fields = {} if is_immutable else None
        self.readStructBegin()
        while True:
            (fname, ftype, fid) = self.readFieldBegin()
            if ftype == TType.STOP:
                break
            if fid < 0:
                # A negative id must not wrap around and alias a field at the
                # end of the spec; treat it as unknown.
                field = None
            else:
                try:
                    field = thrift_spec[fid]
                except IndexError:
                    field = None
            if field is not None and ftype == field[1]:
                fname = field[2]
                fspec = field[3]
                val = self.readFieldByTType(ftype, fspec)
                if is_immutable:
                    fields[fname] = val
                else:
                    setattr(obj, fname, val)
            else:
                self.skip(ftype)
            self.readFieldEnd()
        self.readStructEnd()
        if is_immutable:
            return obj(**fields)

    def writeContainerStruct(self, val, spec):
        """Write struct ``val`` by delegating to its own ``write`` method."""
        val.write(self)

    def writeContainerList(self, val, spec):
        """Write list ``val`` described by ``spec = (etype, espec, _)``."""
        ttype, tspec, _ = spec
        self.writeListBegin(ttype, len(val))
        # The generator writes one element per step; drain it for its effect.
        for _ in self._write_by_ttype(ttype, val, spec, tspec):
            pass
        self.writeListEnd()

    def writeContainerSet(self, val, spec):
        """Write set ``val`` described by ``spec = (etype, espec, _)``."""
        ttype, tspec, _ = spec
        self.writeSetBegin(ttype, len(val))
        for _ in self._write_by_ttype(ttype, val, spec, tspec):
            pass
        self.writeSetEnd()

    def writeContainerMap(self, val, spec):
        """Write map ``val`` described by
        ``spec = (ktype, kspec, vtype, vspec, _)``."""
        ktype, kspec, vtype, vspec, _ = spec
        self.writeMapBegin(ktype, vtype, len(val))
        # zip advances the key writer, then the value writer, for each pair,
        # so keys and values are emitted interleaved.
        for _ in zip(self._write_by_ttype(ktype, six.iterkeys(val), spec, kspec),
                     self._write_by_ttype(vtype, six.itervalues(val), spec, vspec)):
            pass
        self.writeMapEnd()

    def writeStruct(self, obj, thrift_spec):
        """Write ``obj`` field by field according to ``thrift_spec``.

        Each non-None spec entry supplies the field id, type, attribute name
        and nested spec at items 0-3.  Attributes whose value is None are not
        written.
        """
        self.writeStructBegin(obj.__class__.__name__)
        for field in thrift_spec:
            if field is None:
                continue
            fname = field[2]
            val = getattr(obj, fname)
            if val is None:
                # skip writing out unset fields
                continue
            fid = field[0]
            ftype = field[1]
            fspec = field[3]
            self.writeFieldBegin(fname, ftype, fid)
            self.writeFieldByTType(ftype, val, fspec)
            self.writeFieldEnd()
        self.writeFieldStop()
        self.writeStructEnd()

    def _write_by_ttype(self, ttype, vals, spec, espec):
        """Generator that writes one item of ``vals`` per step.

        Yields whatever the underlying writer returns.  ``espec`` is the
        element spec passed to container writers; ``spec`` is accepted for
        signature symmetry and is not used.

        :raises TProtocolException: INVALID_DATA (on first iteration) if no
            writer exists for ``ttype``.
        """
        _, writer_name, is_container = self._ttype_handlers(ttype, espec)
        if writer_name is None:
            # Mirror _read_by_ttype instead of failing inside getattr().
            raise TProtocolException(type=TProtocolException.INVALID_DATA,
                                     message='Invalid type %d' % (ttype))
        writer_func = getattr(self, writer_name)
        write = (lambda v: writer_func(v, espec)) if is_container else writer_func
        for v in vals:
            yield write(v)

    def writeFieldByTType(self, ttype, val, spec):
        """Write the single value ``val`` of type ``ttype`` described by ``spec``."""
        next(self._write_by_ttype(ttype, [val], spec, spec))
409
410
def checkIntegerLimits(i, bits):
    """Verify that ``i`` fits in a signed integer of width ``bits``.

    Only widths 8, 16, 32 and 64 are checked; any other ``bits`` value is
    accepted without validation.

    :raises TProtocolException: INVALID_DATA when ``i`` lies outside the
        signed range for the given width.
    """
    # (width, smallest allowed, largest allowed, error text)
    ranges = (
        (8, -128, 127,
         "i8 requires -128 <= number <= 127"),
        (16, -32768, 32767,
         "i16 requires -32768 <= number <= 32767"),
        (32, -2147483648, 2147483647,
         "i32 requires -2147483648 <= number <= 2147483647"),
        (64, -9223372036854775808, 9223372036854775807,
         "i64 requires -9223372036854775808 <= number <= 9223372036854775807"),
    )
    for width, lowest, highest, complaint in ranges:
        if bits == width and (i < lowest or i > highest):
            raise TProtocolException(TProtocolException.INVALID_DATA, complaint)
424
425
class TProtocolFactory(object):
    """Factory interface for obtaining a protocol bound to a transport."""

    def getProtocol(self, trans):
        """Return a protocol for ``trans``.

        The base implementation does nothing and returns None; subclasses
        are expected to override it.
        """
        return None
429