1#!/usr/bin/env python
2
3#
4# Licensed to the Apache Software Foundation (ASF) under one
5# or more contributor license agreements. See the NOTICE file
6# distributed with this work for additional information
7# regarding copyright ownership. The ASF licenses this file
8# to you under the Apache License, Version 2.0 (the
9# "License"); you may not use this file except in compliance
10# with the License. You may obtain a copy of the License at
11#
12#   http://www.apache.org/licenses/LICENSE-2.0
13#
14# Unless required by applicable law or agreed to in writing,
15# software distributed under the License is distributed on an
16# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17# KIND, either express or implied. See the License for the
18# specific language governing permissions and limitations
19# under the License.
20#
21
22from ThriftTest.ttypes import (
23    Bonk,
24    Bools,
25    LargeDeltas,
26    ListBonks,
27    NestedListsBonk,
28    NestedListsI32x2,
29    NestedListsI32x3,
30    NestedMixedx2,
31    Numberz,
32    VersioningTestV1,
33    VersioningTestV2,
34    Xtruct,
35    Xtruct2,
36)
37
38from Recursive.ttypes import RecTree
39from Recursive.ttypes import RecList
40from Recursive.ttypes import CoRec
41from Recursive.ttypes import CoRec2
42from Recursive.ttypes import VectorTest
43from DebugProtoTest.ttypes import CompactProtoTestStruct, Empty
44from thrift.transport import TTransport
45from thrift.protocol import TBinaryProtocol, TCompactProtocol, TJSONProtocol
46from thrift.TSerialization import serialize, deserialize
47import sys
48import unittest
49
50
51class AbstractTest(unittest.TestCase):
52
53    def setUp(self):
54        self.v1obj = VersioningTestV1(
55            begin_in_both=12345,
56            old_string='aaa',
57            end_in_both=54321,
58        )
59
60        self.v2obj = VersioningTestV2(
61            begin_in_both=12345,
62            newint=1,
63            newbyte=2,
64            newshort=3,
65            newlong=4,
66            newdouble=5.0,
67            newstruct=Bonk(message="Hello!", type=123),
68            newlist=[7, 8, 9],
69            newset=set([42, 1, 8]),
70            newmap={1: 2, 2: 3},
71            newstring="Hola!",
72            end_in_both=54321,
73        )
74
75        self.bools = Bools(im_true=True, im_false=False)
76        self.bools_flipped = Bools(im_true=False, im_false=True)
77
78        self.large_deltas = LargeDeltas(
79            b1=self.bools,
80            b10=self.bools_flipped,
81            b100=self.bools,
82            check_true=True,
83            b1000=self.bools_flipped,
84            check_false=False,
85            vertwo2000=VersioningTestV2(newstruct=Bonk(message='World!', type=314)),
86            a_set2500=set(['lazy', 'brown', 'cow']),
87            vertwo3000=VersioningTestV2(newset=set([2, 3, 5, 7, 11])),
88            big_numbers=[2 ** 8, 2 ** 16, 2 ** 31 - 1, -(2 ** 31 - 1)]
89        )
90
91        self.compact_struct = CompactProtoTestStruct(
92            a_byte=127,
93            a_i16=32000,
94            a_i32=1000000000,
95            a_i64=0xffffffffff,
96            a_double=5.6789,
97            a_string="my string",
98            true_field=True,
99            false_field=False,
100            empty_struct_field=Empty(),
101            byte_list=[-127, -1, 0, 1, 127],
102            i16_list=[-1, 0, 1, 0x7fff],
103            i32_list=[-1, 0, 0xff, 0xffff, 0xffffff, 0x7fffffff],
104            i64_list=[-1, 0, 0xff, 0xffff, 0xffffff, 0xffffffff, 0xffffffffff, 0xffffffffffff, 0xffffffffffffff, 0x7fffffffffffffff],
105            double_list=[0.1, 0.2, 0.3],
106            string_list=["first", "second", "third"],
107            boolean_list=[True, True, True, False, False, False],
108            struct_list=[Empty(), Empty()],
109            byte_set=set([-127, -1, 0, 1, 127]),
110            i16_set=set([-1, 0, 1, 0x7fff]),
111            i32_set=set([1, 2, 3]),
112            i64_set=set([-1, 0, 0xff, 0xffff, 0xffffff, 0xffffffff, 0xffffffffff, 0xffffffffffff, 0xffffffffffffff, 0x7fffffffffffffff]),
113            double_set=set([0.1, 0.2, 0.3]),
114            string_set=set(["first", "second", "third"]),
115            boolean_set=set([True, False]),
116            # struct_set=set([Empty()]), # unhashable instance
117            byte_byte_map={1: 2},
118            i16_byte_map={1: 1, -1: 1, 0x7fff: 1},
119            i32_byte_map={1: 1, -1: 1, 0x7fffffff: 1},
120            i64_byte_map={0: 1, 1: 1, -1: 1, 0x7fffffffffffffff: 1},
121            double_byte_map={-1.1: 1, 1.1: 1},
122            string_byte_map={"first": 1, "second": 2, "third": 3, "": 0},
123            boolean_byte_map={True: 1, False: 0},
124            byte_i16_map={1: 1, 2: -1, 3: 0x7fff},
125            byte_i32_map={1: 1, 2: -1, 3: 0x7fffffff},
126            byte_i64_map={1: 1, 2: -1, 3: 0x7fffffffffffffff},
127            byte_double_map={1: 0.1, 2: -0.1, 3: 1000000.1},
128            byte_string_map={1: "", 2: "blah", 3: "loooooooooooooong string"},
129            byte_boolean_map={1: True, 2: False},
130            # list_byte_map # unhashable
131            # set_byte_map={set([1, 2, 3]) : 1, set([0, 1]) : 2, set([]) : 0}, # unhashable
132            # map_byte_map # unhashable
133            byte_map_map={0: {}, 1: {1: 1}, 2: {1: 1, 2: 2}},
134            byte_set_map={0: set([]), 1: set([1]), 2: set([1, 2])},
135            byte_list_map={0: [], 1: [1], 2: [1, 2]},
136        )
137
138        self.nested_lists_i32x2 = NestedListsI32x2(
139            [
140                [1, 1, 2],
141                [2, 7, 9],
142                [3, 5, 8]
143            ]
144        )
145
146        self.nested_lists_i32x3 = NestedListsI32x3(
147            [
148                [
149                    [2, 7, 9],
150                    [3, 5, 8]
151                ],
152                [
153                    [1, 1, 2],
154                    [1, 4, 9]
155                ]
156            ]
157        )
158
159        self.nested_mixedx2 = NestedMixedx2(int_set_list=[
160            set([1, 2, 3]),
161            set([1, 4, 9]),
162            set([1, 2, 3, 5, 8, 13, 21]),
163            set([-1, 0, 1])
164        ],
165            # note, the sets below are sets of chars, since the strings are iterated
166            map_int_strset={10: set('abc'), 20: set('def'), 30: set('GHI')},
167            map_int_strset_list=[
168                {10: set('abc'), 20: set('def'), 30: set('GHI')},
169                {100: set('lmn'), 200: set('opq'), 300: set('RST')},
170                {1000: set('uvw'), 2000: set('wxy'), 3000: set('XYZ')}]
171        )
172
173        self.nested_lists_bonk = NestedListsBonk(
174            [
175                [
176                    [
177                        Bonk(message='inner A first', type=1),
178                        Bonk(message='inner A second', type=1)
179                    ],
180                    [
181                        Bonk(message='inner B first', type=2),
182                        Bonk(message='inner B second', type=2)
183                    ]
184                ]
185            ]
186        )
187
188        self.list_bonks = ListBonks(
189            [
190                Bonk(message='inner A', type=1),
191                Bonk(message='inner B', type=2),
192                Bonk(message='inner C', type=0)
193            ]
194        )
195
196    def _serialize(self, obj):
197        trans = TTransport.TMemoryBuffer()
198        prot = self.protocol_factory.getProtocol(trans)
199        obj.write(prot)
200        return trans.getvalue()
201
202    def _deserialize(self, objtype, data):
203        prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data))
204        ret = objtype()
205        ret.read(prot)
206        return ret
207
208    def testForwards(self):
209        obj = self._deserialize(VersioningTestV2, self._serialize(self.v1obj))
210        self.assertEquals(obj.begin_in_both, self.v1obj.begin_in_both)
211        self.assertEquals(obj.end_in_both, self.v1obj.end_in_both)
212
213    def testBackwards(self):
214        obj = self._deserialize(VersioningTestV1, self._serialize(self.v2obj))
215        self.assertEquals(obj.begin_in_both, self.v2obj.begin_in_both)
216        self.assertEquals(obj.end_in_both, self.v2obj.end_in_both)
217
218    def testSerializeV1(self):
219        obj = self._deserialize(VersioningTestV1, self._serialize(self.v1obj))
220        self.assertEquals(obj, self.v1obj)
221
222    def testSerializeV2(self):
223        obj = self._deserialize(VersioningTestV2, self._serialize(self.v2obj))
224        self.assertEquals(obj, self.v2obj)
225
226    def testBools(self):
227        self.assertNotEquals(self.bools, self.bools_flipped)
228        self.assertNotEquals(self.bools, self.v1obj)
229        obj = self._deserialize(Bools, self._serialize(self.bools))
230        self.assertEquals(obj, self.bools)
231        obj = self._deserialize(Bools, self._serialize(self.bools_flipped))
232        self.assertEquals(obj, self.bools_flipped)
233        rep = repr(self.bools)
234        self.assertTrue(len(rep) > 0)
235
236    def testLargeDeltas(self):
237        # test large field deltas (meaningful in CompactProto only)
238        obj = self._deserialize(LargeDeltas, self._serialize(self.large_deltas))
239        self.assertEquals(obj, self.large_deltas)
240        rep = repr(self.large_deltas)
241        self.assertTrue(len(rep) > 0)
242
243    def testNestedListsI32x2(self):
244        obj = self._deserialize(NestedListsI32x2, self._serialize(self.nested_lists_i32x2))
245        self.assertEquals(obj, self.nested_lists_i32x2)
246        rep = repr(self.nested_lists_i32x2)
247        self.assertTrue(len(rep) > 0)
248
249    def testNestedListsI32x3(self):
250        obj = self._deserialize(NestedListsI32x3, self._serialize(self.nested_lists_i32x3))
251        self.assertEquals(obj, self.nested_lists_i32x3)
252        rep = repr(self.nested_lists_i32x3)
253        self.assertTrue(len(rep) > 0)
254
255    def testNestedMixedx2(self):
256        obj = self._deserialize(NestedMixedx2, self._serialize(self.nested_mixedx2))
257        self.assertEquals(obj, self.nested_mixedx2)
258        rep = repr(self.nested_mixedx2)
259        self.assertTrue(len(rep) > 0)
260
261    def testNestedListsBonk(self):
262        obj = self._deserialize(NestedListsBonk, self._serialize(self.nested_lists_bonk))
263        self.assertEquals(obj, self.nested_lists_bonk)
264        rep = repr(self.nested_lists_bonk)
265        self.assertTrue(len(rep) > 0)
266
267    def testListBonks(self):
268        obj = self._deserialize(ListBonks, self._serialize(self.list_bonks))
269        self.assertEquals(obj, self.list_bonks)
270        rep = repr(self.list_bonks)
271        self.assertTrue(len(rep) > 0)
272
273    def testCompactStruct(self):
274        # test large field deltas (meaningful in CompactProto only)
275        obj = self._deserialize(CompactProtoTestStruct, self._serialize(self.compact_struct))
276        self.assertEquals(obj, self.compact_struct)
277        rep = repr(self.compact_struct)
278        self.assertTrue(len(rep) > 0)
279
280    def testIntegerLimits(self):
281        if (sys.version_info[0] == 2 and sys.version_info[1] <= 6):
282            print('Skipping testIntegerLimits for Python 2.6')
283            return
284        bad_values = [CompactProtoTestStruct(a_byte=128), CompactProtoTestStruct(a_byte=-129),
285                      CompactProtoTestStruct(a_i16=32768), CompactProtoTestStruct(a_i16=-32769),
286                      CompactProtoTestStruct(a_i32=2147483648), CompactProtoTestStruct(a_i32=-2147483649),
287                      CompactProtoTestStruct(a_i64=9223372036854775808), CompactProtoTestStruct(a_i64=-9223372036854775809)
288                      ]
289
290        for value in bad_values:
291            self.assertRaises(Exception, self._serialize, value)
292
293    def testRecTree(self):
294        """Ensure recursive tree node can be created."""
295        children = []
296        for idx in range(1, 5):
297            node = RecTree(item=idx, children=None)
298            children.append(node)
299
300        parent = RecTree(item=0, children=children)
301        serde_parent = self._deserialize(RecTree, self._serialize(parent))
302        self.assertEquals(0, serde_parent.item)
303        self.assertEquals(4, len(serde_parent.children))
304        for child in serde_parent.children:
305            # Cannot use assertIsInstance in python 2.6?
306            self.assertTrue(isinstance(child, RecTree))
307
308    def _buildLinkedList(self):
309        head = cur = RecList(item=0)
310        for idx in range(1, 5):
311            node = RecList(item=idx)
312            cur.nextitem = node
313            cur = node
314        return head
315
316    def _collapseLinkedList(self, head):
317        out_list = []
318        cur = head
319        while cur is not None:
320            out_list.append(cur.item)
321            cur = cur.nextitem
322        return out_list
323
324    def testRecList(self):
325        """Ensure recursive linked list can be created."""
326        rec_list = self._buildLinkedList()
327        serde_list = self._deserialize(RecList, self._serialize(rec_list))
328        out_list = self._collapseLinkedList(serde_list)
329        self.assertEquals([0, 1, 2, 3, 4], out_list)
330
331    def testCoRec(self):
332        """Ensure co-recursive structures can be created."""
333        item1 = CoRec()
334        item2 = CoRec2()
335
336        item1.other = item2
337        item2.other = item1
338
339        # NOTE [econner724,2017-06-21]: These objects cannot be serialized as serialization
340        # results in an infinite loop. fbthrift also suffers from this
341        # problem.
342
343    def testRecVector(self):
344        """Ensure a list of recursive nodes can be created."""
345        mylist = [self._buildLinkedList(), self._buildLinkedList()]
346        myvec = VectorTest(lister=mylist)
347
348        serde_vec = self._deserialize(VectorTest, self._serialize(myvec))
349        golden_list = [0, 1, 2, 3, 4]
350        for cur_list in serde_vec.lister:
351            out_list = self._collapseLinkedList(cur_list)
352            self.assertEqual(golden_list, out_list)
353
354
class NormalBinaryTest(AbstractTest):
    """Shared round-trip tests over the plain (non-accelerated) binary protocol."""

    protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()
357
358
class AcceleratedBinaryTest(AbstractTest):
    """Shared round-trip tests over the accelerated binary protocol factory.

    ``fallback=False`` is passed explicitly — presumably so a missing native
    accelerator surfaces as a failure instead of silently falling back
    (NOTE(review): confirm against TBinaryProtocolAcceleratedFactory).
    """

    protocol_factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory(fallback=False)
361
362
class CompactProtocolTest(AbstractTest):
    """Shared round-trip tests over the plain (non-accelerated) compact protocol."""

    protocol_factory = TCompactProtocol.TCompactProtocolFactory()
365
366
class AcceleratedCompactTest(AbstractTest):
    """Shared round-trip tests over the accelerated compact protocol factory.

    ``fallback=False`` is passed explicitly — presumably so a missing native
    accelerator surfaces as a failure (NOTE(review): confirm).
    """

    protocol_factory = TCompactProtocol.TCompactProtocolAcceleratedFactory(fallback=False)
369
370
class JSONProtocolTest(AbstractTest):
    """Shared round-trip tests over the JSON protocol."""

    protocol_factory = TJSONProtocol.TJSONProtocolFactory()
373
374
class AcceleratedFramedTest(unittest.TestCase):
    """Checks the accelerated binary protocol reading through TFramedTransport."""

    def testSplit(self):
        """Test FramedTransport and BinaryProtocolAccelerated

        Tests that TBinaryProtocolAccelerated and TFramedTransport
        play nicely together when a read spans a frame"""

        protocol_factory = TBinaryProtocol.TBinaryProtocolAcceleratedFactory()
        bigstring = "".join(chr(byte) for byte in range(ord("a"), ord("z") + 1))

        databuf = TTransport.TMemoryBuffer()
        prot = protocol_factory.getProtocol(databuf)
        prot.writeI32(42)
        prot.writeString(bigstring)
        prot.writeI16(24)
        data = databuf.getvalue()
        # Cut the payload in half so the string value straddles two frames.
        cutpoint = len(data) // 2
        parts = [data[:cutpoint], data[cutpoint:]]

        framed_buffer = TTransport.TMemoryBuffer()
        framed_writer = TTransport.TFramedTransport(framed_buffer)
        for part in parts:
            framed_writer.write(part)
            framed_writer.flush()  # one flush per part -> one frame per part
        # Every frame carries a 4-byte length header in front of its payload.
        self.assertEqual(len(framed_buffer.getvalue()), len(data) + 4 * len(parts))

        # Recreate framed_buffer so we can read from it.
        framed_buffer = TTransport.TMemoryBuffer(framed_buffer.getvalue())
        framed_reader = TTransport.TFramedTransport(framed_buffer)
        prot = protocol_factory.getProtocol(framed_reader)
        self.assertEqual(prot.readI32(), 42)
        self.assertEqual(prot.readString(), bigstring)
        self.assertEqual(prot.readI16(), 24)
408
409
class SerializersTest(unittest.TestCase):
    """Round-trip tests for the ``thrift.TSerialization`` helper functions."""

    def testSerializeThenDeserialize(self):
        """serialize() must be repeatable and deserialize() must invert it."""
        obj = Xtruct2(i32_thing=1,
                      struct_thing=Xtruct(string_thing="foo"))

        s1 = serialize(obj)
        # Repeat a few times: the encoded bytes must not change between calls.
        for _ in range(10):
            self.assertEqual(s1, serialize(obj))
            objcopy = Xtruct2()
            deserialize(objcopy, serialize(obj))
            self.assertEqual(obj, objcopy)

        obj = Xtruct(string_thing="bar")
        objcopy = Xtruct()
        deserialize(objcopy, serialize(obj))
        self.assertEqual(obj, objcopy)

        # test booleans
        obj = Bools(im_true=True, im_false=False)
        objcopy = Bools()
        deserialize(objcopy, serialize(obj))
        self.assertEqual(obj, objcopy)

        # test enums
        def _enumerate_enum(enum_class):
            """Yield (value, name) pairs for either generated enum flavour."""
            if hasattr(enum_class, '_VALUES_TO_NAMES'):
                # old-style enums
                for num, name in enum_class._VALUES_TO_NAMES.items():
                    yield (num, name)
            else:
                # assume Python 3.4+ IntEnum-based
                from enum import IntEnum
                self.assertTrue((issubclass(enum_class, IntEnum)))
                for num in enum_class:
                    yield (num.value, num.name)

        for num, name in _enumerate_enum(Numberz):
            obj = Bonk(message='enum Numberz value %d is string %s' % (num, name), type=num)
            objcopy = Bonk()
            deserialize(objcopy, serialize(obj))
            self.assertEqual(obj, objcopy)
452
453
def suite():
    """Collect every test case class of this module into a single TestSuite."""
    loader = unittest.TestLoader()
    test_suite = unittest.TestSuite()
    case_classes = (
        NormalBinaryTest,
        AcceleratedBinaryTest,
        AcceleratedCompactTest,
        CompactProtocolTest,
        JSONProtocolTest,
        AcceleratedFramedTest,
        SerializersTest,
    )
    for case_class in case_classes:
        test_suite.addTest(loader.loadTestsFromTestCase(case_class))
    return test_suite
466
467
if __name__ == "__main__":
    # Run the suite() collection verbosely when executed as a script.
    _verbose_runner = unittest.TextTestRunner(verbosity=2)
    unittest.main(defaultTest="suite", testRunner=_verbose_runner)
470