1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #include <boost/test/unit_test.hpp>
21 #include <boost/thread.hpp>
22 #include <iostream>
23 #include <climits>
24 #include <vector>
25 #include <thrift/concurrency/Monitor.h>
26 #include <thrift/protocol/TBinaryProtocol.h>
27 #include <thrift/protocol/TJSONProtocol.h>
28 #include <thrift/server/TThreadedServer.h>
29 #include <thrift/transport/THttpServer.h>
30 #include <thrift/transport/THttpClient.h>
31 #include <thrift/transport/TServerSocket.h>
32 #include <thrift/transport/TSocket.h>
33 #include <memory>
34 #include <thrift/transport/TBufferTransports.h>
35 #include "gen-cpp/OneWayService.h"
36 
37 BOOST_AUTO_TEST_SUITE(OneWayHTTPTest)
38 
39 using namespace apache::thrift;
40 using apache::thrift::protocol::TProtocol;
41 using apache::thrift::protocol::TBinaryProtocol;
42 using apache::thrift::protocol::TBinaryProtocolFactory;
43 using apache::thrift::protocol::TJSONProtocol;
44 using apache::thrift::protocol::TJSONProtocolFactory;
45 using apache::thrift::server::TThreadedServer;
46 using apache::thrift::server::TServerEventHandler;
47 using apache::thrift::transport::TTransport;
48 using apache::thrift::transport::THttpServer;
49 using apache::thrift::transport::THttpServerTransportFactory;
50 using apache::thrift::transport::THttpClient;
51 using apache::thrift::transport::TBufferedTransport;
52 using apache::thrift::transport::TBufferedTransportFactory;
53 using apache::thrift::transport::TMemoryBuffer;
54 using apache::thrift::transport::TServerSocket;
55 using apache::thrift::transport::TSocket;
56 using apache::thrift::transport::TTransportException;
57 using std::shared_ptr;
58 using std::cout;
59 using std::cerr;
60 using std::endl;
61 using std::string;
62 namespace utf = boost::unit_test;
63 
// Change the #undef below to #define to enable stderr logging (useful when debugging).
65 #undef ENABLE_STDERR_LOGGING
66 
67 class OneWayServiceHandler : public onewaytest::OneWayServiceIf {
68 public:
69   OneWayServiceHandler() = default;
70 
roundTripRPC()71   void roundTripRPC() override {
72 #ifdef ENABLE_STDERR_LOGGING
73     cerr << "roundTripRPC()" << endl;
74 #endif
75   }
oneWayRPC()76   void oneWayRPC() override {
77 #ifdef ENABLE_STDERR_LOGGING
78     cerr << "oneWayRPC()" << std::endl ;
79 #endif
80  }
81 };
82 
83 class OneWayServiceCloneFactory : virtual public onewaytest::OneWayServiceIfFactory {
84  public:
85   ~OneWayServiceCloneFactory() override = default;
getHandler(const::apache::thrift::TConnectionInfo & connInfo)86   onewaytest::OneWayServiceIf* getHandler(const ::apache::thrift::TConnectionInfo& connInfo) override
87   {
88     (void)connInfo ;
89     return new OneWayServiceHandler;
90   }
releaseHandler(onewaytest::OneWayServiceIf * handler)91   void releaseHandler( onewaytest::OneWayServiceIf* handler) override {
92     delete handler;
93   }
94 };
95 
96 class RPC0ThreadClass {
97 public:
RPC0ThreadClass(TThreadedServer & server)98   RPC0ThreadClass(TThreadedServer& server) : server_(server) { } // Constructor
99 ~RPC0ThreadClass() = default; // Destructor
100 
Run()101 void Run() {
102   server_.serve() ;
103 }
104  TThreadedServer& server_ ;
105 } ;
106 
107 using apache::thrift::concurrency::Monitor;
108 using apache::thrift::concurrency::Mutex;
109 using apache::thrift::concurrency::Synchronized;
110 
111 // copied from IntegrationTest
112 class TServerReadyEventHandler : public TServerEventHandler, public Monitor {
113 public:
TServerReadyEventHandler()114   TServerReadyEventHandler() : isListening_(false), accepted_(0) {}
115   ~TServerReadyEventHandler() override = default;
preServe()116   void preServe() override {
117     Synchronized sync(*this);
118     isListening_ = true;
119     notify();
120   }
createContext(shared_ptr<TProtocol> input,shared_ptr<TProtocol> output)121   void* createContext(shared_ptr<TProtocol> input,
122                               shared_ptr<TProtocol> output) override {
123     Synchronized sync(*this);
124     ++accepted_;
125     notify();
126 
127     (void)input;
128     (void)output;
129     return nullptr;
130   }
isListening() const131   bool isListening() const { return isListening_; }
acceptedCount() const132   uint64_t acceptedCount() const { return accepted_; }
133 
134 private:
135   bool isListening_;
136   uint64_t accepted_;
137 };
138 
139 class TBlockableBufferedTransport : public TBufferedTransport {
140  public:
TBlockableBufferedTransport(std::shared_ptr<TTransport> transport)141   TBlockableBufferedTransport(std::shared_ptr<TTransport> transport)
142     : TBufferedTransport(transport, 10240),
143     blocked_(false) {
144   }
145 
write_buffer_length()146   uint32_t write_buffer_length() {
147     auto have_bytes = static_cast<uint32_t>(wBase_ - wBuf_.get());
148     return have_bytes ;
149   }
150 
block()151   void block() {
152     blocked_ = true ;
153 #ifdef ENABLE_STDERR_LOGGING
154     cerr << "block flushing\n" ;
155 #endif
156  }
unblock()157   void unblock() {
158     blocked_ = false ;
159 #ifdef ENABLE_STDERR_LOGGING
160     cerr << "unblock flushing, buffer is\n<<" << std::string((char *)wBuf_.get(), write_buffer_length()) << ">>\n" ;
161 #endif
162  }
163 
flush()164   void flush() override {
165     if (blocked_) {
166 #ifdef ENABLE_STDERR_LOGGING
167       cerr << "flush was blocked\n" ;
168 #endif
169       return ;
170     }
171     TBufferedTransport::flush() ;
172   }
173 
174   bool blocked_ ;
175 } ;
176 
BOOST_AUTO_TEST_CASE(JSON_BufferedHTTP)177 BOOST_AUTO_TEST_CASE( JSON_BufferedHTTP )
178 {
179   std::shared_ptr<TServerSocket> ss = std::make_shared<TServerSocket>(0) ;
180   TThreadedServer server(
181     std::make_shared<onewaytest::OneWayServiceProcessorFactory>(std::make_shared<OneWayServiceCloneFactory>()),
182     ss, //port
183     std::make_shared<THttpServerTransportFactory>(),
184     std::make_shared<TJSONProtocolFactory>());
185 
186   std::shared_ptr<TServerReadyEventHandler> pEventHandler(new TServerReadyEventHandler) ;
187   server.setServerEventHandler(pEventHandler);
188 
189 #ifdef ENABLE_STDERR_LOGGING
190   cerr << "Starting the server...\n";
191 #endif
192   RPC0ThreadClass t(server) ;
193   boost::thread thread(&RPC0ThreadClass::Run, &t);
194 
195   {
196     Synchronized sync(*(pEventHandler.get()));
197     while (!pEventHandler->isListening()) {
198       pEventHandler->wait();
199     }
200   }
201 
202   int port = ss->getPort() ;
203 #ifdef ENABLE_STDERR_LOGGING
204   cerr << "port " << port << endl ;
205 #endif
206 
207   {
208     std::shared_ptr<TSocket> socket(new TSocket("localhost", port));
209     socket->setRecvTimeout(10000) ; // 1000msec should be enough
210     std::shared_ptr<TBlockableBufferedTransport> blockable_transport(new TBlockableBufferedTransport(socket));
211     std::shared_ptr<TTransport> transport(new THttpClient(blockable_transport, "localhost", "/service"));
212     std::shared_ptr<TProtocol> protocol(new TJSONProtocol(transport));
213     onewaytest::OneWayServiceClient client(protocol);
214 
215 
216     transport->open();
217     client.roundTripRPC();
218     blockable_transport->block() ;
219     uint32_t size0 = blockable_transport->write_buffer_length() ;
220     client.send_oneWayRPC() ;
221     uint32_t size1 = blockable_transport->write_buffer_length() ;
222     client.send_oneWayRPC() ;
223     uint32_t size2 = blockable_transport->write_buffer_length() ;
224     BOOST_CHECK((size1 - size0) == (size2 - size1)) ;
225     blockable_transport->unblock() ;
226     client.send_roundTripRPC() ;
227     blockable_transport->flush() ;
228     try {
229       client.recv_roundTripRPC() ;
230     } catch (const TTransportException &e) {
231       BOOST_ERROR( "we should not get a transport exception -- this means we failed: " + std::string(e.what()) ) ;
232     }
233     transport->close();
234   }
235   server.stop();
236   thread.join() ;
237 #ifdef ENABLE_STDERR_LOGGING
238   cerr << "finished.\n";
239 #endif
240 }
241 
242 BOOST_AUTO_TEST_SUITE_END()
243