1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #ifndef _THRIFT_PROCESSOR_TEST_HANDLERS_H_
20 #define _THRIFT_PROCESSOR_TEST_HANDLERS_H_ 1
21 
#include <atomic>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "EventLog.h"
#include "gen-cpp/ParentService.h"
#include "gen-cpp/ChildService.h"
25 
26 namespace apache {
27 namespace thrift {
28 namespace test {
29 
30 class ParentHandler : virtual public ParentServiceIf {
31 public:
ParentHandler(const std::shared_ptr<EventLog> & log)32   ParentHandler(const std::shared_ptr<EventLog>& log)
33     : triggerMonitor(&mutex_), generation_(0), wait_(false), log_(log) {}
34 
incrementGeneration()35   int32_t incrementGeneration() override {
36     concurrency::Guard g(mutex_);
37     log_->append(EventLog::ET_CALL_INCREMENT_GENERATION, 0, 0);
38     return ++generation_;
39   }
40 
getGeneration()41   int32_t getGeneration() override {
42     concurrency::Guard g(mutex_);
43     log_->append(EventLog::ET_CALL_GET_GENERATION, 0, 0);
44     return generation_;
45   }
46 
addString(const std::string & s)47   void addString(const std::string& s) override {
48     concurrency::Guard g(mutex_);
49     log_->append(EventLog::ET_CALL_ADD_STRING, 0, 0);
50     strings_.push_back(s);
51   }
52 
getStrings(std::vector<std::string> & _return)53   void getStrings(std::vector<std::string>& _return) override {
54     concurrency::Guard g(mutex_);
55     log_->append(EventLog::ET_CALL_GET_STRINGS, 0, 0);
56     _return = strings_;
57   }
58 
getDataWait(std::string & _return,const int32_t length)59   void getDataWait(std::string& _return, const int32_t length) override {
60     concurrency::Guard g(mutex_);
61     log_->append(EventLog::ET_CALL_GET_DATA_WAIT, 0, 0);
62 
63     blockUntilTriggered();
64 
65     _return.append(length, 'a');
66   }
67 
onewayWait()68   void onewayWait() override {
69     concurrency::Guard g(mutex_);
70     log_->append(EventLog::ET_CALL_ONEWAY_WAIT, 0, 0);
71 
72     blockUntilTriggered();
73   }
74 
exceptionWait(const std::string & message)75   void exceptionWait(const std::string& message) override {
76     concurrency::Guard g(mutex_);
77     log_->append(EventLog::ET_CALL_EXCEPTION_WAIT, 0, 0);
78 
79     blockUntilTriggered();
80 
81     MyError e;
82     e.message = message;
83     throw e;
84   }
85 
unexpectedExceptionWait(const std::string & message)86   void unexpectedExceptionWait(const std::string& message) override {
87     concurrency::Guard g(mutex_);
88     log_->append(EventLog::ET_CALL_UNEXPECTED_EXCEPTION_WAIT, 0, 0);
89 
90     blockUntilTriggered();
91 
92     MyError e;
93     e.message = message;
94     throw e;
95   }
96 
97   /**
98    * After prepareTriggeredCall() is invoked, calls to any of the *Wait()
99    * functions won't return until triggerPendingCalls() is invoked
100    *
101    * This has to be a separate function invoked by the main test thread
102    * in order to to avoid race conditions.
103    */
prepareTriggeredCall()104   void prepareTriggeredCall() {
105     concurrency::Guard g(mutex_);
106     wait_ = true;
107   }
108 
109   /**
110    * Wake up all calls waiting in blockUntilTriggered()
111    */
triggerPendingCalls()112   void triggerPendingCalls() {
113     concurrency::Guard g(mutex_);
114     wait_ = false;
115     triggerMonitor.notifyAll();
116   }
117 
118 protected:
119   /**
120    * blockUntilTriggered() won't return until triggerPendingCalls() is invoked
121    * in another thread.
122    *
123    * This should only be called when already holding mutex_.
124    */
blockUntilTriggered()125   void blockUntilTriggered() {
126     while (wait_) {
127       triggerMonitor.waitForever();
128     }
129 
130     // Log an event when we return
131     log_->append(EventLog::ET_WAIT_RETURN, 0, 0);
132   }
133 
134   concurrency::Mutex mutex_;
135   concurrency::Monitor triggerMonitor;
136   int32_t generation_;
137   bool wait_;
138   std::vector<std::string> strings_;
139   std::shared_ptr<EventLog> log_;
140 };
141 
142 #ifdef _MSC_VER
143   #pragma warning( push )
144   #pragma warning (disable : 4250 ) //inheriting methods via dominance
145 #endif
146 
147 class ChildHandler : public ParentHandler, virtual public ChildServiceIf {
148 public:
ChildHandler(const std::shared_ptr<EventLog> & log)149   ChildHandler(const std::shared_ptr<EventLog>& log) : ParentHandler(log), value_(0) {}
150 
setValue(const int32_t value)151   int32_t setValue(const int32_t value) override {
152     concurrency::Guard g(mutex_);
153     log_->append(EventLog::ET_CALL_SET_VALUE, 0, 0);
154 
155     int32_t oldValue = value_;
156     value_ = value;
157     return oldValue;
158   }
159 
getValue()160   int32_t getValue() override {
161     concurrency::Guard g(mutex_);
162     log_->append(EventLog::ET_CALL_GET_VALUE, 0, 0);
163 
164     return value_;
165   }
166 
167 protected:
168   int32_t value_;
169 };
170 
171 #ifdef _MSC_VER
172   #pragma warning( pop )
173 #endif
174 
175 struct ConnContext {
176 public:
ConnContextConnContext177   ConnContext(std::shared_ptr<protocol::TProtocol> in,
178               std::shared_ptr<protocol::TProtocol> out,
179               uint32_t id)
180     : input(in), output(out), id(id) {}
181 
182   std::shared_ptr<protocol::TProtocol> input;
183   std::shared_ptr<protocol::TProtocol> output;
184   uint32_t id;
185 };
186 
187 struct CallContext {
188 public:
CallContextCallContext189   CallContext(ConnContext* context, uint32_t id, const std::string& name)
190     : connContext(context), name(name), id(id) {}
191 
192   ConnContext* connContext;
193   std::string name;
194   uint32_t id;
195 };
196 
197 class ServerEventHandler : public server::TServerEventHandler {
198 public:
ServerEventHandler(const std::shared_ptr<EventLog> & log)199   ServerEventHandler(const std::shared_ptr<EventLog>& log) : nextId_(1), log_(log) {}
200 
preServe()201   void preServe() override {}
202 
createContext(std::shared_ptr<protocol::TProtocol> input,std::shared_ptr<protocol::TProtocol> output)203   void* createContext(std::shared_ptr<protocol::TProtocol> input,
204                               std::shared_ptr<protocol::TProtocol> output) override {
205     ConnContext* context = new ConnContext(input, output, nextId_);
206     ++nextId_;
207     log_->append(EventLog::ET_CONN_CREATED, context->id, 0);
208     return context;
209   }
210 
deleteContext(void * serverContext,std::shared_ptr<protocol::TProtocol> input,std::shared_ptr<protocol::TProtocol> output)211   void deleteContext(void* serverContext,
212                              std::shared_ptr<protocol::TProtocol> input,
213                              std::shared_ptr<protocol::TProtocol> output) override {
214     auto* context = reinterpret_cast<ConnContext*>(serverContext);
215 
216     if (input != context->input) {
217       abort();
218     }
219     if (output != context->output) {
220       abort();
221     }
222 
223     log_->append(EventLog::ET_CONN_DESTROYED, context->id, 0);
224 
225     delete context;
226   }
227 
processContext(void * serverContext,std::shared_ptr<transport::TTransport> transport)228   void processContext(void* serverContext,
229                               std::shared_ptr<transport::TTransport> transport) override {
230 // TODO: We currently don't test the behavior of the processContext()
231 // calls.  The various server implementations call processContext() at
232 // slightly different times, and it is too annoying to try and account for
233 // their various differences.
234 //
235 // TThreadedServer, TThreadPoolServer, and TSimpleServer usually wait until
236 // they see the first byte of a request before calling processContext().
237 // However, they don't wait for the first byte of the very first request,
238 // and instead immediately call processContext() before any data is
239 // received.
240 //
241 // TNonblockingServer always waits until receiving the full request before
242 // calling processContext().
243 #if 0
244     ConnContext* context = reinterpret_cast<ConnContext*>(serverContext);
245     log_->append(EventLog::ET_PROCESS, context->id, 0);
246 #else
247     THRIFT_UNUSED_VARIABLE(serverContext);
248     THRIFT_UNUSED_VARIABLE(transport);
249 #endif
250   }
251 
252 protected:
253   uint32_t nextId_;
254   std::shared_ptr<EventLog> log_;
255 };
256 
257 class ProcessorEventHandler : public TProcessorEventHandler {
258 public:
ProcessorEventHandler(const std::shared_ptr<EventLog> & log)259   ProcessorEventHandler(const std::shared_ptr<EventLog>& log) : nextId_(1), log_(log) {}
260 
getContext(const char * fnName,void * serverContext)261   void* getContext(const char* fnName, void* serverContext) override {
262     auto* connContext = reinterpret_cast<ConnContext*>(serverContext);
263 
264     CallContext* context = new CallContext(connContext, nextId_, fnName);
265     ++nextId_;
266 
267     log_->append(EventLog::ET_CALL_STARTED, connContext->id, context->id, fnName);
268     return context;
269   }
270 
freeContext(void * ctx,const char * fnName)271   void freeContext(void* ctx, const char* fnName) override {
272     auto* context = reinterpret_cast<CallContext*>(ctx);
273     checkName(context, fnName);
274     log_->append(EventLog::ET_CALL_FINISHED, context->connContext->id, context->id, fnName);
275     delete context;
276   }
277 
preRead(void * ctx,const char * fnName)278   void preRead(void* ctx, const char* fnName) override {
279     auto* context = reinterpret_cast<CallContext*>(ctx);
280     checkName(context, fnName);
281     log_->append(EventLog::ET_PRE_READ, context->connContext->id, context->id, fnName);
282   }
283 
postRead(void * ctx,const char * fnName,uint32_t bytes)284   void postRead(void* ctx, const char* fnName, uint32_t bytes) override {
285     THRIFT_UNUSED_VARIABLE(bytes);
286     auto* context = reinterpret_cast<CallContext*>(ctx);
287     checkName(context, fnName);
288     log_->append(EventLog::ET_POST_READ, context->connContext->id, context->id, fnName);
289   }
290 
preWrite(void * ctx,const char * fnName)291   void preWrite(void* ctx, const char* fnName) override {
292     auto* context = reinterpret_cast<CallContext*>(ctx);
293     checkName(context, fnName);
294     log_->append(EventLog::ET_PRE_WRITE, context->connContext->id, context->id, fnName);
295   }
296 
postWrite(void * ctx,const char * fnName,uint32_t bytes)297   void postWrite(void* ctx, const char* fnName, uint32_t bytes) override {
298     THRIFT_UNUSED_VARIABLE(bytes);
299     auto* context = reinterpret_cast<CallContext*>(ctx);
300     checkName(context, fnName);
301     log_->append(EventLog::ET_POST_WRITE, context->connContext->id, context->id, fnName);
302   }
303 
asyncComplete(void * ctx,const char * fnName)304   void asyncComplete(void* ctx, const char* fnName) override {
305     auto* context = reinterpret_cast<CallContext*>(ctx);
306     checkName(context, fnName);
307     log_->append(EventLog::ET_ASYNC_COMPLETE, context->connContext->id, context->id, fnName);
308   }
309 
handlerError(void * ctx,const char * fnName)310   void handlerError(void* ctx, const char* fnName) override {
311     auto* context = reinterpret_cast<CallContext*>(ctx);
312     checkName(context, fnName);
313     log_->append(EventLog::ET_HANDLER_ERROR, context->connContext->id, context->id, fnName);
314   }
315 
316 protected:
checkName(const CallContext * context,const char * fnName)317   void checkName(const CallContext* context, const char* fnName) {
318     // Note: we can't use BOOST_CHECK_EQUAL here, since the handler runs in a
319     // different thread from the test functions.  Just abort if the names are
320     // different
321     if (context->name != fnName) {
322       fprintf(stderr,
323               "call context name mismatch: \"%s\" != \"%s\"\n",
324               context->name.c_str(),
325               fnName);
326       fflush(stderr);
327       abort();
328     }
329   }
330 
331   uint32_t nextId_;
332   std::shared_ptr<EventLog> log_;
333 };
334 }
335 }
336 } // apache::thrift::test
337 
338 #endif // _THRIFT_PROCESSOR_TEST_HANDLERS_H_
339