#!/usr/bin/env python

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

from ThriftTest.ttypes import Bonk, VersioningTestV1, VersioningTestV2
from thrift.protocol import TJSONProtocol
from thrift.transport import TTransport

import json
import unittest


class SimpleJSONProtocolTest(unittest.TestCase):
    """Tests for the write-only TSimpleJSONProtocol.

    Structs are serialized with the protocol and the output is parsed with
    the standard ``json`` module, then compared against a plain-dict
    expectation. Reading with the protocol is expected to fail.
    """

    protocol_factory = TJSONProtocol.TSimpleJSONProtocolFactory()

    def _assertDictEqual(self, a, b, msg=None):
        """Assert that mappings *a* and *b* have the same keys and values.

        Delegates to ``unittest.TestCase.assertDictEqual`` when the running
        ``unittest`` provides it (Python 2.7+). Otherwise it falls back to a
        manual comparison.
        """
        if hasattr(self, 'assertDictEqual'):
            # assertDictEqual is only available on Python 2.7 and later.
            self.assertDictEqual(a, b, msg)
            return

        # Substitute implementation; its failure messages are less detailed
        # than unittest's. It uses items()/assertEqual rather than
        # iteritems()/assertEquals so it also works on Python 3: iteritems()
        # does not exist there, and the assertEquals alias was removed in 3.12.
        self.assertEqual(len(a), len(b), msg)
        for k, v in a.items():
            self.assertTrue(k in b, msg)
            self.assertEqual(b.get(k), v, msg)

    def _serialize(self, obj):
        """Write *obj* with ``protocol_factory`` and return the raw bytes."""
        trans = TTransport.TMemoryBuffer()
        prot = self.protocol_factory.getProtocol(trans)
        obj.write(prot)
        return trans.getvalue()

    def _deserialize(self, objtype, data):
        """Read a new *objtype* instance from *data* using ``protocol_factory``."""
        prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data))
        ret = objtype()
        ret.read(prot)
        return ret

    def testWriteOnly(self):
        """Reading through the simple JSON protocol raises NotImplementedError."""
        self.assertRaises(NotImplementedError,
                          self._deserialize, VersioningTestV1, b'{}')

    def testSimpleMessage(self):
        """A struct of scalar fields serializes to a flat JSON object."""
        v1obj = VersioningTestV1(
            begin_in_both=12345,
            old_string='aaa',
            end_in_both=54321)
        expected = dict(begin_in_both=v1obj.begin_in_both,
                        old_string=v1obj.old_string,
                        end_in_both=v1obj.end_in_both)
        actual = json.loads(self._serialize(v1obj).decode('ascii'))

        self._assertDictEqual(expected, actual)

    def testComplicated(self):
        """Nested structs and container fields serialize to JSON equivalents."""
        v2obj = VersioningTestV2(
            begin_in_both=12345,
            newint=1,
            newbyte=2,
            newshort=3,
            newlong=4,
            newdouble=5.0,
            newstruct=Bonk(message="Hello!", type=123),
            newlist=[7, 8, 9],
            newset=set([42, 1, 8]),
            newmap={1: 2, 2: 3},
            newstring="Hola!",
            end_in_both=54321)
        expected = dict(begin_in_both=v2obj.begin_in_both,
                        newint=v2obj.newint,
                        newbyte=v2obj.newbyte,
                        newshort=v2obj.newshort,
                        newlong=v2obj.newlong,
                        newdouble=v2obj.newdouble,
                        newstruct=dict(message=v2obj.newstruct.message,
                                       type=v2obj.newstruct.type),
                        newlist=v2obj.newlist,
                        # Compared as a list, so this relies on the protocol
                        # writing set elements in the set's own iteration order.
                        newset=list(v2obj.newset),
                        newmap=v2obj.newmap,
                        newstring=v2obj.newstring,
                        end_in_both=v2obj.end_in_both)

        # JSON object keys are always strings, so integer map keys come back
        # as strings (1 -> "1"). Round-trip the expectation through json so
        # its keys get the same conversion.
        expected = json.loads(json.dumps(expected))
        actual = json.loads(self._serialize(v2obj).decode('ascii'))
        self._assertDictEqual(expected, actual)


if __name__ == '__main__':
    unittest.main()