1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10#   http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
import unittest

import _import_local_thrift  # noqa
from thrift.Thrift import TType
from thrift.protocol.TBinaryProtocol import TBinaryProtocol
from thrift.transport import TTransport
25
26
def testNaked(type, data):
    """Round-trip a single bare value through TBinaryProtocol.

    The value is written with ``protocol.write<Kind>`` into an in-memory
    buffer behind a buffered transport, flushed, and then read back with
    ``protocol.read<Kind>`` from a fresh protocol over the serialized bytes.

    :param type: case-insensitive type name, one of 'Byte', 'I16', 'I32',
        'I64', 'String', 'Double', 'Binary' or 'Bool'. (The name shadows the
        builtin but is kept so existing callers keep working.)
    :param data: the value to serialize.
    :returns: the value decoded from the serialized bytes.
    :raises ValueError: if ``type`` is not one of the supported names
        (previously this silently wrote nothing and returned None).
    """
    kind = type.capitalize()
    if kind not in ('Byte', 'I16', 'I32', 'I64',
                    'String', 'Double', 'Binary', 'Bool'):
        raise ValueError('unsupported type: %r' % (type,))

    buf = TTransport.TMemoryBuffer()
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    protocol = TBinaryProtocol(transport)
    # Each supported kind has a matching writeX/readX method pair, so the
    # method name is derived from the (whitelisted) kind.
    getattr(protocol, 'write' + kind)(data)
    transport.flush()

    buf = TTransport.TMemoryBuffer(buf.getvalue())
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    protocol = TBinaryProtocol(transport)
    return getattr(protocol, 'read' + kind)()


# This is a helper, not a test: stop pytest/nose from collecting it because of
# its "test" name prefix (it would error looking for 'type'/'data' fixtures).
testNaked.__test__ = False
83
84
def testField(type, data):
    """Round-trip a value stored as field 10 of a struct through TBinaryProtocol.

    Writes ``struct { field 10: <type> = data }`` into an in-memory buffer
    behind a buffered transport, then reads the struct back from a fresh
    protocol and returns the decoded field value.

    :param type: case-insensitive type name, one of 'Bool', 'Byte', 'I16',
        'I32', 'I64', 'Double', 'String' or 'Binary'. (The name shadows the
        builtin but is kept so existing callers keep working.)
    :param data: the value to serialize as the field's payload.
    :returns: the field value decoded from the serialized bytes.
    :raises KeyError: if ``type`` is not one of the supported names.
    :raises AssertionError: if the field header read back does not carry the
        wire type id and field id that were written.
    """
    # Wire type ids come from thrift.Thrift.TType rather than hand-typed
    # numbers. Binary has no id of its own: it is serialized like a string.
    field_types = {
        'Bool': TType.BOOL,
        'Byte': TType.BYTE,
        'I16': TType.I16,
        'I32': TType.I32,
        'I64': TType.I64,
        'Double': TType.DOUBLE,
        'String': TType.STRING,
        'Binary': TType.STRING,
    }
    field_id = 10
    kind = type.capitalize()
    field_type = field_types[kind]

    buf = TTransport.TMemoryBuffer()
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    protocol = TBinaryProtocol(transport)
    protocol.writeStructBegin('struct')
    protocol.writeFieldBegin("field", field_type, field_id)
    # Each supported kind has a matching writeX/readX method pair.
    getattr(protocol, 'write' + kind)(data)
    protocol.writeFieldEnd()
    protocol.writeStructEnd()
    transport.flush()

    buf = TTransport.TMemoryBuffer(buf.getvalue())
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    protocol = TBinaryProtocol(transport)
    protocol.readStructBegin()
    _, read_type, read_id = protocol.readFieldBegin()
    # Raise explicitly (not `assert`) so the check survives `python -O`.
    if (read_type, read_id) != (field_type, field_id):
        raise AssertionError(
            'field header mismatch: expected (type=%r, id=%r), got (type=%r, id=%r)'
            % (field_type, field_id, read_type, read_id))
    value = getattr(protocol, 'read' + kind)()
    # Consume the field/struct end markers so the read side mirrors the writes.
    protocol.readFieldEnd()
    protocol.readStructEnd()
    return value


# This is a helper, not a test: stop pytest/nose from collecting it because of
# its "test" name prefix (it would error looking for 'type'/'data' fixtures).
testField.__test__ = False
153
154
def testMessage(data):
    """Round-trip a message header through TBinaryProtocol.

    Writes ``writeMessageBegin(name, message_type, seqid)`` followed by
    ``writeMessageEnd()`` into an in-memory buffer behind a buffered
    transport, then reads the header back from a fresh protocol.

    :param data: indexable ``(name, message_type, seqid)``; only the first
        three items are used.
    :returns: whatever ``readMessageBegin()`` returns; callers index it as
        ``(name, type, seqid)``.
    """
    name, message_type, seqid = data[0], data[1], data[2]

    buf = TTransport.TMemoryBuffer()
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    protocol = TBinaryProtocol(transport)
    protocol.writeMessageBegin(name, message_type, seqid)
    protocol.writeMessageEnd()
    transport.flush()

    buf = TTransport.TMemoryBuffer(buf.getvalue())
    transport = TTransport.TBufferedTransportFactory().getTransport(buf)
    protocol = TBinaryProtocol(transport)
    result = protocol.readMessageBegin()
    protocol.readMessageEnd()
    return result


# This is a helper, not a test: stop pytest/nose from collecting it because of
# its "test" name prefix (it would error looking for a 'data' fixture).
testMessage.__test__ = False
176
177
class TestTBinaryProtocol(unittest.TestCase):
    """Round-trip tests for TBinaryProtocol over a buffered in-memory transport.

    Each check serializes a value, either bare (``testNaked``) or as a struct
    field (``testField``), and asserts the value read back equals the original.
    """

    def _assert_round_trips(self, helper, kind, values):
        """Assert ``helper(kind, value) == value`` for every item of ``values``."""
        for value in values:
            self.assertEqual(value, helper(kind, value),
                             '%s(%r, %r) did not round-trip'
                             % (helper.__name__, kind, value))

    def test_TBinaryProtocol_write_read(self):
        # Byte: a few bare values, then the whole signed 8-bit range as a field.
        self._assert_round_trips(testNaked, 'Byte', [0, 123, 127, -1, -128])
        self._assert_round_trips(testField, 'Byte', range(-128, 128))

        self._assert_round_trips(testNaked, 'I16', [
            0, 1, 15000, 0x7fff, -1, -15000, -0x7fff, 32767, -32768])
        self._assert_round_trips(testField, 'I16', [
            0, 1, 7, 150, 15000, 0x7fff, -1, -7, -150, -15000, -0xfff,
            -32768])

        self._assert_round_trips(testNaked, 'I32', [
            0, 1, 15000, 0xffff, -1, -15000, -0xffff,
            2147483647, -2147483647, -2147483648])
        self._assert_round_trips(testField, 'I32', [
            0, 1, 7, 150, 15000, 31337, 0xffff, 0xffffff,
            -1, -7, -150, -15000, -0xffff, -0xffffff])

        self._assert_round_trips(testNaked, 'I64', [
            0, 9223372036854775807, -9223372036854775807,
            -9223372036854775808])
        self._assert_round_trips(testField, 'I64', [
            0, 1, -1, 1 << 40, -(1 << 40),
            9223372036854775807, -9223372036854775808])

        self._assert_round_trips(testNaked, 'Bool', [True, False])
        self._assert_round_trips(testField, 'Bool', [True, False])

        # Doubles are compared exactly: serialization is expected to be lossless.
        self._assert_round_trips(testNaked, 'Double', [3.14159261, 3.1415926])
        self._assert_round_trips(testField, 'Double', [3.14159261, 3.1415926])

        self._assert_round_trips(testNaked, 'String', ['hello thrift', ''])
        self._assert_round_trips(testField, 'String', ['hello thrift', ''])

        # Binary payloads, including NUL and high-bit bytes.
        self._assert_round_trips(testNaked, 'Binary', [b'', b'\x00\x01\xfe\xff'])
        self._assert_round_trips(testField, 'Binary', [b'', b'\x00\x01\xfe\xff'])

        message_types = {"T_CALL": 1, "T_REPLY": 2, "T_EXCEPTION": 3, "T_ONEWAY": 4}
        test_data = [("short message name", message_types['T_CALL'], 0),
                     ("1", message_types['T_REPLY'], 12345),
                     ("loooooooooooooooooooooooooooooooooong", message_types['T_EXCEPTION'], 1 << 16),
                     ("one way push", message_types['T_ONEWAY'], 12),
                     ("Janky", message_types['T_CALL'], 0)]

        for dt in test_data:
            # readMessageBegin yields (name, type, seqid) in the same order.
            self.assertEqual(dt, tuple(testMessage(dt)))
261
262
if __name__ == "__main__":
    # Executing this module directly runs its test cases through unittest.
    unittest.main()
265