1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #include <thrift/concurrency/ThreadManager.h>
21 #include <thrift/concurrency/ThreadFactory.h>
22 #include <thrift/concurrency/Monitor.h>
23 #include <thrift/concurrency/Mutex.h>
24 #include <thrift/protocol/TBinaryProtocol.h>
25 #include <thrift/server/TSimpleServer.h>
26 #include <thrift/server/TThreadPoolServer.h>
27 #include <thrift/server/TThreadedServer.h>
28 #include <thrift/transport/TServerSocket.h>
29 #include <thrift/transport/TSocket.h>
30 #include <thrift/transport/TTransportUtils.h>
31 #include <thrift/transport/TFileTransport.h>
32 #include <thrift/TLogging.h>
33 
34 #include "Service.h"
35 #include <iostream>
36 #include <set>
37 #include <stdexcept>
38 #include <sstream>
39 #include <map>
40 #if _WIN32
41 #include <thrift/windows/TWinsockSingleton.h>
42 #endif
43 
44 using namespace std;
45 
46 using namespace apache::thrift;
47 using namespace apache::thrift::async;
48 using namespace apache::thrift::protocol;
49 using namespace apache::thrift::transport;
50 using namespace apache::thrift::server;
51 using namespace apache::thrift::concurrency;
52 
53 using namespace test::stress;
54 
/**
 * Equality predicate for NUL-terminated C strings: compares the characters
 * the pointers refer to rather than the pointer values themselves.
 */
struct eqstr {
  bool operator()(const char* s1, const char* s2) const {
    const int order = strcmp(s1, s2);
    return order == 0;
  }
};
58 
/**
 * Strict-weak-ordering predicate for NUL-terminated C strings: orders by
 * string contents (as strcmp does) instead of by pointer value.
 */
struct ltstr {
  bool operator()(const char* s1, const char* s2) const {
    const int order = strcmp(s1, s2);
    return order < 0;
  }
};
62 
63 // typedef hash_map<const char*, int, hash<const char*>, eqstr> count_map;
64 typedef map<const char*, int, ltstr> count_map;
65 
66 class Server : public ServiceIf {
67 public:
68   Server() = default;
69 
count(const char * method)70   void count(const char* method) {
71     Guard m(lock_);
72     int ct = counts_[method];
73     counts_[method] = ++ct;
74   }
75 
echoVoid()76   void echoVoid() override {
77     count("echoVoid");
78     return;
79   }
80 
getCount()81   count_map getCount() {
82     Guard m(lock_);
83     return counts_;
84   }
85 
echoByte(const int8_t arg)86   int8_t echoByte(const int8_t arg) override { return arg; }
echoI32(const int32_t arg)87   int32_t echoI32(const int32_t arg) override { return arg; }
echoI64(const int64_t arg)88   int64_t echoI64(const int64_t arg) override { return arg; }
echoString(string & out,const string & arg)89   void echoString(string& out, const string& arg) override {
90     if (arg != "hello") {
91       T_ERROR_ABORT("WRONG STRING (%s)!!!!", arg.c_str());
92     }
93     out = arg;
94   }
echoList(vector<int8_t> & out,const vector<int8_t> & arg)95   void echoList(vector<int8_t>& out, const vector<int8_t>& arg) override { out = arg; }
echoSet(set<int8_t> & out,const set<int8_t> & arg)96   void echoSet(set<int8_t>& out, const set<int8_t>& arg) override { out = arg; }
echoMap(map<int8_t,int8_t> & out,const map<int8_t,int8_t> & arg)97   void echoMap(map<int8_t, int8_t>& out, const map<int8_t, int8_t>& arg) override { out = arg; }
98 
99 private:
100   count_map counts_;
101   Mutex lock_;
102 };
103 
/**
 * Tells a ClientThread whether it is responsible for opening and closing the
 * transport it was given, or whether the transport is managed by its creator.
 */
enum TransportOpenCloseBehavior {
  OpenAndCloseTransportInThread = 0,    // the thread opens/closes the transport itself
  DontOpenAndCloseTransportInThread = 1 // the transport is opened (and owned) elsewhere
};
108 class ClientThread : public Runnable {
109 public:
ClientThread(std::shared_ptr<TTransport> transport,std::shared_ptr<ServiceIf> client,Monitor & monitor,size_t & workerCount,size_t loopCount,TType loopType,TransportOpenCloseBehavior behavior)110   ClientThread(std::shared_ptr<TTransport> transport,
111                std::shared_ptr<ServiceIf> client,
112                Monitor& monitor,
113                size_t& workerCount,
114                size_t loopCount,
115                TType loopType,
116                TransportOpenCloseBehavior behavior)
117     : _transport(transport),
118       _client(client),
119       _monitor(monitor),
120       _workerCount(workerCount),
121       _loopCount(loopCount),
122       _loopType(loopType),
123       _behavior(behavior) {}
124 
run()125   void run() override {
126 
127     // Wait for all worker threads to start
128 
129     {
130       Synchronized s(_monitor);
131       while (_workerCount == 0) {
132         _monitor.wait();
133       }
134     }
135 
136     _startTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now().time_since_epoch()).count();
137     if(_behavior == OpenAndCloseTransportInThread) {
138       _transport->open();
139     }
140 
141     switch (_loopType) {
142     case T_VOID:
143       loopEchoVoid();
144       break;
145     case T_BYTE:
146       loopEchoByte();
147       break;
148     case T_I32:
149       loopEchoI32();
150       break;
151     case T_I64:
152       loopEchoI64();
153       break;
154     case T_STRING:
155       loopEchoString();
156       break;
157     default:
158       cerr << "Unexpected loop type" << _loopType << endl;
159       break;
160     }
161 
162     _endTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now().time_since_epoch()).count();
163 
164     if(_behavior == OpenAndCloseTransportInThread) {
165       _transport->close();
166     }
167 
168     _done = true;
169 
170     {
171       Synchronized s(_monitor);
172 
173       _workerCount--;
174 
175       if (_workerCount == 0) {
176 
177         _monitor.notify();
178       }
179     }
180   }
181 
loopEchoVoid()182   void loopEchoVoid() {
183     for (size_t ix = 0; ix < _loopCount; ix++) {
184       _client->echoVoid();
185     }
186   }
187 
loopEchoByte()188   void loopEchoByte() {
189     for (size_t ix = 0; ix < _loopCount; ix++) {
190       int8_t arg = 1;
191       int8_t result;
192       result = _client->echoByte(arg);
193       (void)result;
194       assert(result == arg);
195     }
196   }
197 
loopEchoI32()198   void loopEchoI32() {
199     for (size_t ix = 0; ix < _loopCount; ix++) {
200       int32_t arg = 1;
201       int32_t result;
202       result = _client->echoI32(arg);
203       (void)result;
204       assert(result == arg);
205     }
206   }
207 
loopEchoI64()208   void loopEchoI64() {
209     for (size_t ix = 0; ix < _loopCount; ix++) {
210       int64_t arg = 1;
211       int64_t result;
212       result = _client->echoI64(arg);
213       (void)result;
214       assert(result == arg);
215     }
216   }
217 
loopEchoString()218   void loopEchoString() {
219     for (size_t ix = 0; ix < _loopCount; ix++) {
220       string arg = "hello";
221       string result;
222       _client->echoString(result, arg);
223       assert(result == arg);
224     }
225   }
226 
227   std::shared_ptr<TTransport> _transport;
228   std::shared_ptr<ServiceIf> _client;
229   Monitor& _monitor;
230   size_t& _workerCount;
231   size_t _loopCount;
232   TType _loopType;
233   int64_t _startTime;
234   int64_t _endTime;
235   bool _done;
236   Monitor _sleep;
237   TransportOpenCloseBehavior _behavior;
238 };
239 
240 class TStartObserver : public apache::thrift::server::TServerEventHandler {
241 public:
TStartObserver()242   TStartObserver() : awake_(false) {}
preServe()243   void preServe() override {
244     apache::thrift::concurrency::Synchronized s(m_);
245     awake_ = true;
246     m_.notifyAll();
247   }
waitForService()248   void waitForService() {
249     apache::thrift::concurrency::Synchronized s(m_);
250     while (!awake_)
251       m_.waitForever();
252   }
253 
254 private:
255   apache::thrift::concurrency::Monitor m_;
256   bool awake_;
257 };
258 
main(int argc,char ** argv)259 int main(int argc, char** argv) {
260 #if _WIN32
261   transport::TWinsockSingleton::create();
262 #endif
263 
264   int port = 9091;
265   string clientType = "regular";
266   string serverType = "thread-pool";
267   string protocolType = "binary";
268   size_t workerCount = 8;
269   size_t clientCount = 4;
270   size_t loopCount = 50000;
271   TType loopType = T_VOID;
272   string callName = "echoVoid";
273   bool runServer = true;
274   bool logRequests = false;
275   string requestLogPath = "./requestlog.tlog";
276   bool replayRequests = false;
277 
278   ostringstream usage;
279 
280   usage << argv[0] << " [--port=<port number>] [--server] [--server-type=<server-type>] "
281                       "[--protocol-type=<protocol-type>] [--workers=<worker-count>] "
282                       "[--clients=<client-count>] [--loop=<loop-count>] "
283                       "[--client-type=<client-type>]" << endl
284         << "\tclients        Number of client threads to create - 0 implies no clients, i.e. "
285                             "server only.  Default is " << clientCount << endl
286         << "\thelp           Prints this help text." << endl
287         << "\tcall           Service method to call.  Default is " << callName << endl
288         << "\tloop           The number of remote thrift calls each client makes.  Default is " << loopCount << endl
289         << "\tport           The port the server and clients should bind to "
290                             "for thrift network connections.  Default is " << port << endl
291         << "\tserver         Run the Thrift server in this process.  Default is " << runServer << endl
292         << "\tserver-type    Type of server, \"simple\" or \"thread-pool\".  Default is " << serverType << endl
293         << "\tprotocol-type  Type of protocol, \"binary\", \"ascii\", or \"xml\".  Default is " << protocolType << endl
294         << "\tlog-request    Log all request to ./requestlog.tlog. Default is " << logRequests << endl
295         << "\treplay-request Replay requests from log file (./requestlog.tlog) Default is " << replayRequests << endl
296         << "\tworkers        Number of thread pools workers.  Only valid "
297                             "for thread-pool server type.  Default is " << workerCount << endl
298         << "\tclient-type    Type of client, \"regular\" or \"concurrent\".  Default is " << clientType << endl
299         << endl;
300 
301   map<string, string> args;
302 
303   for (int ix = 1; ix < argc; ix++) {
304 
305     string arg(argv[ix]);
306 
307     if (arg.compare(0, 2, "--") == 0) {
308 
309       size_t end = arg.find_first_of("=", 2);
310 
311       string key = string(arg, 2, end - 2);
312 
313       if (end != string::npos) {
314         args[key] = string(arg, end + 1);
315       } else {
316         args[key] = "true";
317       }
318     } else {
319       throw invalid_argument("Unexcepted command line token: " + arg);
320     }
321   }
322 
323   try {
324 
325     if (!args["clients"].empty()) {
326       clientCount = atoi(args["clients"].c_str());
327     }
328 
329     if (!args["help"].empty()) {
330       cerr << usage.str();
331       return 0;
332     }
333 
334     if (!args["loop"].empty()) {
335       loopCount = atoi(args["loop"].c_str());
336     }
337 
338     if (!args["call"].empty()) {
339       callName = args["call"];
340     }
341 
342     if (!args["port"].empty()) {
343       port = atoi(args["port"].c_str());
344     }
345 
346     if (!args["server"].empty()) {
347       runServer = args["server"] == "true";
348     }
349 
350     if (!args["log-request"].empty()) {
351       logRequests = args["log-request"] == "true";
352     }
353 
354     if (!args["replay-request"].empty()) {
355       replayRequests = args["replay-request"] == "true";
356     }
357 
358     if (!args["server-type"].empty()) {
359       serverType = args["server-type"];
360 
361       if (serverType == "simple") {
362 
363       } else if (serverType == "thread-pool") {
364 
365       } else if (serverType == "threaded") {
366 
367       } else {
368 
369         throw invalid_argument("Unknown server type " + serverType);
370       }
371     }
372     if (!args["client-type"].empty()) {
373       clientType = args["client-type"];
374 
375       if (clientType == "regular") {
376 
377       } else if (clientType == "concurrent") {
378 
379       } else {
380 
381         throw invalid_argument("Unknown client type " + clientType);
382       }
383     }
384     if (!args["workers"].empty()) {
385       workerCount = atoi(args["workers"].c_str());
386     }
387 
388   } catch (std::exception& e) {
389     cerr << e.what() << endl;
390     cerr << usage.str();
391   }
392 
393   std::shared_ptr<ThreadFactory> threadFactory
394       = std::shared_ptr<ThreadFactory>(new ThreadFactory());
395 
396   // Dispatcher
397   std::shared_ptr<Server> serviceHandler(new Server());
398 
399   if (replayRequests) {
400     std::shared_ptr<Server> serviceHandler(new Server());
401     std::shared_ptr<ServiceProcessor> serviceProcessor(new ServiceProcessor(serviceHandler));
402 
403     // Transports
404     std::shared_ptr<TFileTransport> fileTransport(new TFileTransport(requestLogPath));
405     fileTransport->setChunkSize(2 * 1024 * 1024);
406     fileTransport->setMaxEventSize(1024 * 16);
407     fileTransport->seekToEnd();
408 
409     // Protocol Factory
410     std::shared_ptr<TProtocolFactory> protocolFactory(new TBinaryProtocolFactory());
411 
412     TFileProcessor fileProcessor(serviceProcessor, protocolFactory, fileTransport);
413 
414     fileProcessor.process(0, true);
415     exit(0);
416   }
417 
418   if (runServer) {
419 
420     std::shared_ptr<ServiceProcessor> serviceProcessor(new ServiceProcessor(serviceHandler));
421 
422     // Transport
423     std::shared_ptr<TServerSocket> serverSocket(new TServerSocket(port));
424 
425     // Transport Factory
426     std::shared_ptr<TTransportFactory> transportFactory(new TBufferedTransportFactory());
427 
428     // Protocol Factory
429     std::shared_ptr<TProtocolFactory> protocolFactory(new TBinaryProtocolFactory());
430 
431     if (logRequests) {
432       // initialize the log file
433       std::shared_ptr<TFileTransport> fileTransport(new TFileTransport(requestLogPath));
434       fileTransport->setChunkSize(2 * 1024 * 1024);
435       fileTransport->setMaxEventSize(1024 * 16);
436 
437       transportFactory
438           = std::shared_ptr<TTransportFactory>(new TPipedTransportFactory(fileTransport));
439     }
440 
441     std::shared_ptr<TServer> server;
442 
443     if (serverType == "simple") {
444 
445       server.reset(
446           new TSimpleServer(serviceProcessor, serverSocket, transportFactory, protocolFactory));
447 
448     } else if (serverType == "threaded") {
449 
450       server.reset(
451           new TThreadedServer(serviceProcessor, serverSocket, transportFactory, protocolFactory));
452 
453     } else if (serverType == "thread-pool") {
454 
455       std::shared_ptr<ThreadManager> threadManager
456           = ThreadManager::newSimpleThreadManager(workerCount);
457 
458       threadManager->threadFactory(threadFactory);
459       threadManager->start();
460       server.reset(new TThreadPoolServer(serviceProcessor,
461                                          serverSocket,
462                                          transportFactory,
463                                          protocolFactory,
464                                          threadManager));
465     }
466 
467     std::shared_ptr<TStartObserver> observer(new TStartObserver);
468     server->setServerEventHandler(observer);
469     std::shared_ptr<Thread> serverThread = threadFactory->newThread(server);
470 
471     cerr << "Starting the server on port " << port << endl;
472 
473     serverThread->start();
474     observer->waitForService();
475 
476     // If we aren't running clients, just wait forever for external clients
477     if (clientCount == 0) {
478       serverThread->join();
479     }
480   }
481 
482   if (clientCount > 0) { //FIXME: start here for client type?
483 
484     Monitor monitor;
485 
486     size_t threadCount = 0;
487 
488     set<std::shared_ptr<Thread> > clientThreads;
489 
490     if (callName == "echoVoid") {
491       loopType = T_VOID;
492     } else if (callName == "echoByte") {
493       loopType = T_BYTE;
494     } else if (callName == "echoI32") {
495       loopType = T_I32;
496     } else if (callName == "echoI64") {
497       loopType = T_I64;
498     } else if (callName == "echoString") {
499       loopType = T_STRING;
500     } else {
501       throw invalid_argument("Unknown service call " + callName);
502     }
503 
504     if(clientType == "regular") {
505       for (size_t ix = 0; ix < clientCount; ix++) {
506 
507         std::shared_ptr<TSocket> socket(new TSocket("127.0.0.1", port));
508         std::shared_ptr<TBufferedTransport> bufferedSocket(new TBufferedTransport(socket, 2048));
509         std::shared_ptr<TProtocol> protocol(new TBinaryProtocol(bufferedSocket));
510         std::shared_ptr<ServiceClient> serviceClient(new ServiceClient(protocol));
511 
512         clientThreads.insert(threadFactory->newThread(std::shared_ptr<ClientThread>(
513             new ClientThread(socket, serviceClient, monitor, threadCount, loopCount, loopType, OpenAndCloseTransportInThread))));
514       }
515     } else if(clientType == "concurrent") {
516       std::shared_ptr<TSocket> socket(new TSocket("127.0.0.1", port));
517       std::shared_ptr<TBufferedTransport> bufferedSocket(new TBufferedTransport(socket, 2048));
518       std::shared_ptr<TProtocol> protocol(new TBinaryProtocol(bufferedSocket));
519       auto sync = std::make_shared<TConcurrentClientSyncInfo>();
520       std::shared_ptr<ServiceConcurrentClient> serviceClient(new ServiceConcurrentClient(protocol, sync));
521       socket->open();
522       for (size_t ix = 0; ix < clientCount; ix++) {
523         clientThreads.insert(threadFactory->newThread(std::shared_ptr<ClientThread>(
524             new ClientThread(socket, serviceClient, monitor, threadCount, loopCount, loopType, DontOpenAndCloseTransportInThread))));
525       }
526     }
527 
528     for (auto thread = clientThreads.begin();
529          thread != clientThreads.end();
530          thread++) {
531       (*thread)->start();
532     }
533 
534     int64_t time00;
535     int64_t time01;
536 
537     {
538       Synchronized s(monitor);
539       threadCount = clientCount;
540 
541       cerr << "Launch " << clientCount << " " << clientType << " client threads" << endl;
542 
543       time00 = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now().time_since_epoch()).count();
544 
545       monitor.notifyAll();
546 
547       while (threadCount > 0) {
548         monitor.wait();
549       }
550 
551       time01 = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now().time_since_epoch()).count();
552     }
553 
554     int64_t firstTime = 9223372036854775807LL;
555     int64_t lastTime = 0;
556 
557     double averageTime = 0;
558     int64_t minTime = 9223372036854775807LL;
559     int64_t maxTime = 0;
560 
561     for (auto ix = clientThreads.begin();
562          ix != clientThreads.end();
563          ix++) {
564 
565       std::shared_ptr<ClientThread> client
566           = std::dynamic_pointer_cast<ClientThread>((*ix)->runnable());
567 
568       int64_t delta = client->_endTime - client->_startTime;
569 
570       assert(delta > 0);
571 
572       if (client->_startTime < firstTime) {
573         firstTime = client->_startTime;
574       }
575 
576       if (client->_endTime > lastTime) {
577         lastTime = client->_endTime;
578       }
579 
580       if (delta < minTime) {
581         minTime = delta;
582       }
583 
584       if (delta > maxTime) {
585         maxTime = delta;
586       }
587 
588       averageTime += delta;
589     }
590 
591     averageTime /= clientCount;
592 
593     cout << "workers :" << workerCount << ", client : " << clientCount << ", loops : " << loopCount
594          << ", rate : " << (clientCount * loopCount * 1000) / ((double)(time01 - time00)) << endl;
595 
596     count_map count = serviceHandler->getCount();
597     count_map::iterator iter;
598     for (iter = count.begin(); iter != count.end(); ++iter) {
599       printf("%s => %d\n", iter->first, iter->second);
600     }
601     cerr << "done." << endl;
602   }
603 
604   return 0;
605 }
606