#!/usr/bin/env python

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

"""Tests that transports and protocols raise EOFError on truncated input."""

from ThriftTest.ttypes import Xtruct
from thrift.transport import TTransport
from thrift.protocol import TBinaryProtocol
from thrift.protocol import TCompactProtocol
import unittest


class TestEof(unittest.TestCase):
    """EOF-handling tests for TMemoryBuffer and the binary/compact protocols."""

    def make_data(self, pfactory=None):
        """Serialize two Xtruct records and return the raw bytes.

        :param pfactory: protocol factory used to encode the records; when
            None, a plain TBinaryProtocol is used.
        :returns: the contents of the memory buffer after both writes.
        """
        trans = TTransport.TMemoryBuffer()
        if pfactory is not None:
            prot = pfactory.getProtocol(trans)
        else:
            prot = TBinaryProtocol.TBinaryProtocol(trans)

        # Two records with distinct field values so the readers below can
        # confirm they decoded each one in order.
        for string_thing, byte_thing in (("Zero", 0), ("One", 1)):
            x = Xtruct()
            x.string_thing = string_thing
            x.byte_thing = byte_thing
            x.write(prot)

        return trans.getvalue()

    def testTransportReadAll(self):
        """Test that TMemoryBuffer.readAll throws an EOFError when asked for more bytes than remain"""
        trans = TTransport.TMemoryBuffer(self.make_data())
        # A short read that fits in the buffer must succeed...
        trans.readAll(1)

        # ...while a request far larger than the remaining data must not.
        with self.assertRaises(EOFError):
            trans.readAll(10000)

    def eofTestHelper(self, pfactory):
        """Decode both records written by make_data, then expect EOFError.

        :param pfactory: protocol factory used for both encoding and decoding.
        """
        trans = TTransport.TMemoryBuffer(self.make_data(pfactory))
        prot = pfactory.getProtocol(trans)

        x = Xtruct()
        x.read(prot)
        self.assertEqual(x.string_thing, "Zero")
        self.assertEqual(x.byte_thing, 0)

        x = Xtruct()
        x.read(prot)
        self.assertEqual(x.string_thing, "One")
        self.assertEqual(x.byte_thing, 1)

        # The stream is now exhausted; a third read must hit EOF.
        with self.assertRaises(EOFError):
            Xtruct().read(prot)

    def eofTestHelperStress(self, pfactory):
        """Check the protocol raises EOFError for every truncation of the stream.

        Every prefix of the serialized data (from empty up to and including
        the full data) is fed to a fresh protocol, and three records are read
        from it. make_data writes only two, so every prefix must end in an
        EOFError.

        :param pfactory: protocol factory used for both encoding and decoding.
        """
        # TODO: we should make sure this covers more of the code paths

        data = self.make_data(pfactory)
        for length in range(len(data) + 1):
            prot = pfactory.getProtocol(TTransport.TMemoryBuffer(data[:length]))
            x = Xtruct()
            try:
                for _ in range(3):
                    x.read(prot)
            except EOFError:
                continue
            self.fail("Should have gotten an EOFError (prefix length %d of %d)"
                      % (length, len(data)))

    def _assertProtocolEof(self, pfactory):
        """Run both the sequential and the truncation EOF checks for pfactory."""
        self.eofTestHelper(pfactory)
        self.eofTestHelperStress(pfactory)

    def testBinaryProtocolEof(self):
        """Test that TBinaryProtocol throws an EOFError when it reaches the end of the stream"""
        self._assertProtocolEof(TBinaryProtocol.TBinaryProtocolFactory())

    def testBinaryProtocolAcceleratedBinaryEof(self):
        """Test that TBinaryProtocolAccelerated throws an EOFError when it reaches the end of the stream"""
        # fallback=False: exercise the accelerated implementation rather than
        # silently falling back to the pure-Python protocol.
        self._assertProtocolEof(
            TBinaryProtocol.TBinaryProtocolAcceleratedFactory(fallback=False))

    def testCompactProtocolEof(self):
        """Test that TCompactProtocol throws an EOFError when it reaches the end of the stream"""
        self._assertProtocolEof(TCompactProtocol.TCompactProtocolFactory())

    def testCompactProtocolAcceleratedCompactEof(self):
        """Test that TCompactProtocolAccelerated throws an EOFError when it reaches the end of the stream"""
        self._assertProtocolEof(
            TCompactProtocol.TCompactProtocolAcceleratedFactory(fallback=False))


def suite():
    """Return a TestSuite containing every test in TestEof."""
    test_suite = unittest.TestSuite()
    loader = unittest.TestLoader()
    test_suite.addTest(loader.loadTestsFromTestCase(TestEof))
    return test_suite


if __name__ == "__main__":
    unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2))