#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import unittest

import _import_local_thrift  # noqa
from thrift.protocol.TBinaryProtocol import TBinaryProtocol
from thrift.Thrift import TMessageType, TType
from thrift.transport import TTransport


# Wire type for each value kind the round-trip helpers support, keyed by the
# ``str.capitalize()`` form of the kind name. The method names
# ``write<Kind>`` / ``read<Kind>`` on the protocol are derived from the same
# key. "Binary" shares TType.STRING because Thrift encodes both as
# length-prefixed bytes.
_FIELD_TTYPES = {
    'Bool': TType.BOOL,
    'Byte': TType.BYTE,
    'Double': TType.DOUBLE,
    'I16': TType.I16,
    'I32': TType.I32,
    'I64': TType.I64,
    'String': TType.STRING,
    'Binary': TType.STRING,
}

# Field id that testField writes and then expects to read back.
_FIELD_ID = 10


def _kind(name):
    """Return the canonical kind key for *name* (case-insensitive).

    Raises KeyError if the kind is not one of ``_FIELD_TTYPES``.
    """
    kind = name.capitalize()
    _FIELD_TTYPES[kind]  # fail loudly on typos instead of silently skipping
    return kind


def _encode(write):
    """Call ``write(protocol)`` on an in-memory TBinaryProtocol; return the bytes."""
    buf = TTransport.TMemoryBuffer()
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    write(TBinaryProtocol(transport))
    # TBufferedTransport holds writes until flushed into the memory buffer.
    transport.flush()
    return buf.getvalue()


def _decoder(payload):
    """Return a TBinaryProtocol that reads *payload* from memory."""
    buf = TTransport.TMemoryBuffer(payload)
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    return TBinaryProtocol(transport)


def testNaked(type, data):
    """Round-trip one bare value through TBinaryProtocol and return what is read.

    ``type`` is a case-insensitive kind name ("byte", "i16", "i32", "i64",
    "string", "double", "binary", "bool"). The parameter keeps its historical
    name even though it shadows the builtin. Raises KeyError for an unknown
    kind.
    """
    kind = _kind(type)
    payload = _encode(lambda protocol: getattr(protocol, 'write' + kind)(data))
    return getattr(_decoder(payload), 'read' + kind)()


def testField(type, data):
    """Round-trip *data* as field ``_FIELD_ID`` of a one-field struct.

    Writes a struct containing a single field of the wire type that matches
    ``type``, followed by the STOP marker. It then reads the struct back,
    verifying the field header and the terminating STOP, and returns the
    decoded value. Raises KeyError for an unknown kind. Raises
    AssertionError if the header or terminator read back is wrong.
    """
    kind = _kind(type)
    ttype = _FIELD_TTYPES[kind]

    def write(protocol):
        protocol.writeStructBegin('struct')
        protocol.writeFieldBegin('field', ttype, _FIELD_ID)
        getattr(protocol, 'write' + kind)(data)
        protocol.writeFieldEnd()
        # A well-formed struct is terminated by a STOP field header.
        protocol.writeFieldStop()
        protocol.writeStructEnd()

    protocol = _decoder(_encode(write))
    protocol.readStructBegin()
    _, read_ttype, read_id = protocol.readFieldBegin()
    if (read_ttype, read_id) != (ttype, _FIELD_ID):
        raise AssertionError(
            'field header mismatch: got (%r, %r), expected (%r, %r)'
            % (read_ttype, read_id, ttype, _FIELD_ID))
    value = getattr(protocol, 'read' + kind)()
    protocol.readFieldEnd()
    _, stop_ttype, _ = protocol.readFieldBegin()
    if stop_ttype != TType.STOP:
        raise AssertionError(
            'expected STOP after the field, got type %r' % (stop_ttype,))
    protocol.readStructEnd()
    return value


def testMessage(data):
    """Round-trip a message header and return ``readMessageBegin()``'s result.

    *data* is indexed as ``(name, message_type, seqid)``.
    """
    name, message_type, seqid = data[0], data[1], data[2]

    def write(protocol):
        protocol.writeMessageBegin(name, message_type, seqid)
        protocol.writeMessageEnd()

    protocol = _decoder(_encode(write))
    result = protocol.readMessageBegin()
    protocol.readMessageEnd()
    return result


# The helpers above match pytest's default ``test*`` collection pattern but
# take arguments; mark them as non-tests so pytest does not try to run them
# and fail looking for fixtures named ``type`` / ``data``.
testNaked.__test__ = False
testField.__test__ = False
testMessage.__test__ = False


class TestTBinaryProtocol(unittest.TestCase):

    def _assert_round_trips(self, helper, kind, values):
        """Assert ``helper(kind, v) == v`` for every ``v`` in *values*."""
        for value in values:
            self.assertEqual(
                value, helper(kind, value),
                '%s(%r, %r) did not round-trip'
                % (helper.__name__, kind, value))

    def test_TBinaryProtocol_write_read(self):
        # Byte
        self.assertEqual(123, testNaked('Byte', 123))
        for i in range(0, 128):
            self.assertEqual(i, testField('Byte', i))
            self.assertEqual(-i, testField('Byte', -i))
        self.assertEqual(-128, testField('Byte', -128))

        # I16
        self._assert_round_trips(
            testNaked, 'I16',
            (0, 1, 15000, 0x7fff, -1, -15000, -0x7fff, 32767, -32768))
        self._assert_round_trips(
            testField, 'I16',
            (0, 1, 7, 150, 15000, 0x7fff, -1, -7, -150, -15000, -0xfff,
             -0x8000))

        # I32
        self._assert_round_trips(
            testNaked, 'I32',
            (0, 1, 15000, 0xffff, -1, -15000, -0xffff, 2147483647,
             -2147483647, -2147483648))
        self._assert_round_trips(
            testField, 'I32',
            (0, 1, 7, 150, 15000, 31337, 0xffff, 0xffffff, -1, -7, -150,
             -15000, -0xffff, -0xffffff))

        # I64
        self._assert_round_trips(
            testNaked, 'I64',
            (9223372036854775807, -9223372036854775807,
             -9223372036854775808, 0))
        self._assert_round_trips(
            testField, 'I64', (0, 1, -1, 9223372036854775807))

        # Bool
        self._assert_round_trips(testNaked, 'Bool', (True, False))
        self._assert_round_trips(testField, 'Bool', (True, False))

        # Double
        self._assert_round_trips(testNaked, 'Double', (3.14159261, 3.1415926))
        self._assert_round_trips(testField, 'Double', (3.1415926,))

        # String / Binary
        self._assert_round_trips(testNaked, 'String', ('hello thrift',))
        self._assert_round_trips(testField, 'String', ('hello thrift',))
        self._assert_round_trips(testNaked, 'Binary', (b'\x00\x01\xfe\xff',))
        self._assert_round_trips(testField, 'Binary', (b'\x00\x01\xfe\xff',))

        # Message headers
        test_data = [("short message name", TMessageType.CALL, 0),
                     ("1", TMessageType.REPLY, 12345),
                     ("loooooooooooooooooooooooooooooooooong", TMessageType.EXCEPTION, 1 << 16),
                     ("one way push", TMessageType.ONEWAY, 12),
                     ("Janky", TMessageType.CALL, 0)]

        for dt in test_data:
            result = testMessage(dt)
            self.assertEqual(result[0], dt[0])
            self.assertEqual(result[1], dt[1])
            self.assertEqual(result[2], dt[2])


if __name__ == '__main__':
    unittest.main()