1#!/usr/bin/env python
2
3#
4# Licensed to the Apache Software Foundation (ASF) under one
5# or more contributor license agreements. See the NOTICE file
6# distributed with this work for additional information
7# regarding copyright ownership. The ASF licenses this file
8# to you under the Apache License, Version 2.0 (the
9# "License"); you may not use this file except in compliance
10# with the License. You may obtain a copy of the License at
11#
12#   http://www.apache.org/licenses/LICENSE-2.0
13#
14# Unless required by applicable law or agreed to in writing,
15# software distributed under the License is distributed on an
16# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17# KIND, either express or implied. See the License for the
18# specific language governing permissions and limitations
19# under the License.
20#
21
22from ThriftTest.ttypes import Bonk, VersioningTestV1, VersioningTestV2
23from thrift.protocol import TJSONProtocol
24from thrift.transport import TTransport
25
26import json
27import unittest
28
29
class SimpleJSONProtocolTest(unittest.TestCase):
    """Tests for the write-only TSimpleJSONProtocol.

    Thrift structs are serialized with TSimpleJSONProtocolFactory, and the
    resulting output is parsed back with the standard ``json`` module. The
    assertions therefore compare plain Python dicts rather than Thrift objects.
    """

    # Shared by every test; instantiated once when the class body is executed.
    protocol_factory = TJSONProtocol.TSimpleJSONProtocolFactory()

    def _assertDictEqual(self, a, b, msg=None):
        """Assert that dicts ``a`` and ``b`` hold the same keys and values.

        Delegates to ``unittest.TestCase.assertDictEqual`` when it exists
        (Python 2.7 and every Python 3 version). Otherwise it falls back to a
        simpler key-by-key comparison.

        :param a: the expected dict (callers in this class pass it first).
        :param b: the actual dict.
        :param msg: optional failure message forwarded to the underlying asserts.
        """
        if hasattr(self, 'assertDictEqual'):
            self.assertDictEqual(a, b, msg)
            return

        # Fallback for unittest versions without assertDictEqual. It is not as
        # informative as the library implementation. It sticks to APIs present
        # on both Python 2 and 3:
        # - ``assertEqual``: the ``assertEquals`` alias is deprecated and was
        #   removed in Python 3.12.
        # - ``dict.items()``: ``iteritems()`` does not exist on Python 3.
        self.assertEqual(len(a), len(b), msg)
        for k, v in a.items():
            self.assertTrue(k in b, msg)
            self.assertEqual(b.get(k), v, msg)

    def _serialize(self, obj):
        """Write Thrift struct ``obj`` with the simple JSON protocol.

        :param obj: a Thrift struct exposing ``write(protocol)``.
        :returns: the contents of the in-memory transport after the write.
            Callers in this class decode the result as ASCII.
        """
        trans = TTransport.TMemoryBuffer()
        prot = self.protocol_factory.getProtocol(trans)
        obj.write(prot)
        return trans.getvalue()

    def _deserialize(self, objtype, data):
        """Construct an ``objtype`` and read it from ``data`` via the protocol.

        The simple JSON protocol is write-only, so testWriteOnly expects this
        call to raise NotImplementedError.

        :param objtype: a Thrift struct class with a no-argument constructor.
        :param data: serialized input used to seed the memory transport.
        :returns: the populated ``objtype`` instance.
        """
        prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data))
        ret = objtype()
        ret.read(prot)
        return ret

    def testWriteOnly(self):
        """Reading with the simple JSON protocol raises NotImplementedError."""
        self.assertRaises(NotImplementedError,
                          self._deserialize, VersioningTestV1, b'{}')

    def testSimpleMessage(self):
        """A flat struct serializes to a JSON object keyed by field name."""
        v1obj = VersioningTestV1(
            begin_in_both=12345,
            old_string='aaa',
            end_in_both=54321)
        expected = dict(begin_in_both=v1obj.begin_in_both,
                        old_string=v1obj.old_string,
                        end_in_both=v1obj.end_in_both)
        actual = json.loads(self._serialize(v1obj).decode('ascii'))

        self._assertDictEqual(expected, actual)

    def testComplicated(self):
        """Scalars, a nested struct, and list/set/map fields serialize to JSON."""
        v2obj = VersioningTestV2(
            begin_in_both=12345,
            newint=1,
            newbyte=2,
            newshort=3,
            newlong=4,
            newdouble=5.0,
            newstruct=Bonk(message="Hello!", type=123),
            newlist=[7, 8, 9],
            newset=set([42, 1, 8]),
            newmap={1: 2, 2: 3},
            newstring="Hola!",
            end_in_both=54321)
        expected = dict(begin_in_both=v2obj.begin_in_both,
                        newint=v2obj.newint,
                        newbyte=v2obj.newbyte,
                        newshort=v2obj.newshort,
                        newlong=v2obj.newlong,
                        newdouble=v2obj.newdouble,
                        # Nested structs become nested JSON objects.
                        newstruct=dict(message=v2obj.newstruct.message,
                                       type=v2obj.newstruct.type),
                        newlist=v2obj.newlist,
                        # Sets are expected as JSON arrays. This assumes the
                        # serializer iterates the set in the same order list()
                        # does for this same set object.
                        newset=list(v2obj.newset),
                        newmap=v2obj.newmap,
                        newstring=v2obj.newstring,
                        end_in_both=v2obj.end_in_both)

        # JSON object keys are always strings, so the integer keys of newmap
        # come back as "1" and "2". Round-tripping ``expected`` through json
        # applies the same conversion before comparing.
        expected = json.loads(json.dumps(expected))
        actual = json.loads(self._serialize(v2obj).decode('ascii'))
        self._assertDictEqual(expected, actual)
105
106
if __name__ == "__main__":
    # Executed as a script: discover and run the TestCase classes defined in
    # this module (module='__main__' is unittest.main's documented default).
    unittest.main(module="__main__")
109