1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #include <thrift/concurrency/ThreadManager.h>
21 #include <thrift/concurrency/ThreadFactory.h>
22 #include <thrift/concurrency/Monitor.h>
23 #include <thrift/concurrency/Mutex.h>
24 #include <thrift/protocol/TBinaryProtocol.h>
25 #include <thrift/server/TSimpleServer.h>
26 #include <thrift/server/TThreadPoolServer.h>
27 #include <thrift/server/TThreadedServer.h>
28 #include <thrift/server/TNonblockingServer.h>
29 #include <thrift/transport/TServerSocket.h>
30 #include <thrift/transport/TSocket.h>
31 #include <thrift/transport/TNonblockingServerSocket.h>
32 #include <thrift/transport/TTransportUtils.h>
33 #include <thrift/transport/TFileTransport.h>
34 #include <thrift/TLogging.h>
35 
36 #include "Service.h"
37 
38 #include <iostream>
39 #include <set>
40 #include <stdexcept>
41 #include <sstream>
42 #include <map>
43 #if _WIN32
44 #include <thrift/windows/TWinsockSingleton.h>
45 #endif
46 
47 using namespace std;
48 
49 using namespace apache::thrift;
50 using namespace apache::thrift::protocol;
51 using namespace apache::thrift::transport;
52 using namespace apache::thrift::server;
53 using namespace apache::thrift::concurrency;
54 
55 using namespace test::stress;
56 
/**
 * Equality predicate for NUL-terminated C strings.
 *
 * Compares the pointed-to characters (via strcmp), not the pointer values, so
 * two distinct buffers holding the same text are considered equal.
 */
struct eqstr {
  bool operator()(const char* lhs, const char* rhs) const {
    const int order = strcmp(lhs, rhs);
    return order == 0;
  }
};
60 
/**
 * Strict weak ordering for NUL-terminated C strings.
 *
 * Orders by character content (via strcmp) rather than by pointer address, so
 * it can key associative containers on `const char*` text.
 */
struct ltstr {
  bool operator()(const char* lhs, const char* rhs) const {
    const int order = strcmp(lhs, rhs);
    return order < 0;
  }
};

// Per-method call counters, keyed by method name and ordered by ltstr.
// typedef hash_map<const char*, int, hash<const char*>, eqstr> count_map;
using count_map = map<const char*, int, ltstr>;
67 
68 class Server : public ServiceIf {
69 public:
70   Server() = default;
71 
count(const char * method)72   void count(const char* method) {
73     Guard m(lock_);
74     int ct = counts_[method];
75     counts_[method] = ++ct;
76   }
77 
echoVoid()78   void echoVoid() override {
79     count("echoVoid");
80     // Sleep to simulate work
81     THRIFT_SLEEP_USEC(1);
82     return;
83   }
84 
getCount()85   count_map getCount() {
86     Guard m(lock_);
87     return counts_;
88   }
89 
echoByte(const int8_t arg)90   int8_t echoByte(const int8_t arg) override { return arg; }
echoI32(const int32_t arg)91   int32_t echoI32(const int32_t arg) override { return arg; }
echoI64(const int64_t arg)92   int64_t echoI64(const int64_t arg) override { return arg; }
echoString(string & out,const string & arg)93   void echoString(string& out, const string& arg) override {
94     if (arg != "hello") {
95       T_ERROR_ABORT("WRONG STRING (%s)!!!!", arg.c_str());
96     }
97     out = arg;
98   }
echoList(vector<int8_t> & out,const vector<int8_t> & arg)99   void echoList(vector<int8_t>& out, const vector<int8_t>& arg) override { out = arg; }
echoSet(set<int8_t> & out,const set<int8_t> & arg)100   void echoSet(set<int8_t>& out, const set<int8_t>& arg) override { out = arg; }
echoMap(map<int8_t,int8_t> & out,const map<int8_t,int8_t> & arg)101   void echoMap(map<int8_t, int8_t>& out, const map<int8_t, int8_t>& arg) override { out = arg; }
102 
103 private:
104   count_map counts_;
105   Mutex lock_;
106 };
107 
108 class ClientThread : public Runnable {
109 public:
ClientThread(std::shared_ptr<TTransport> transport,std::shared_ptr<ServiceClient> client,Monitor & monitor,size_t & workerCount,size_t loopCount,TType loopType)110   ClientThread(std::shared_ptr<TTransport> transport,
111                std::shared_ptr<ServiceClient> client,
112                Monitor& monitor,
113                size_t& workerCount,
114                size_t loopCount,
115                TType loopType)
116     : _transport(transport),
117       _client(client),
118       _monitor(monitor),
119       _workerCount(workerCount),
120       _loopCount(loopCount),
121       _loopType(loopType) {}
122 
run()123   void run() override {
124 
125     // Wait for all worker threads to start
126 
127     {
128       Synchronized s(_monitor);
129       while (_workerCount == 0) {
130         _monitor.wait();
131       }
132     }
133 
134     _startTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now().time_since_epoch()).count();
135 
136     _transport->open();
137 
138     switch (_loopType) {
139     case T_VOID:
140       loopEchoVoid();
141       break;
142     case T_BYTE:
143       loopEchoByte();
144       break;
145     case T_I32:
146       loopEchoI32();
147       break;
148     case T_I64:
149       loopEchoI64();
150       break;
151     case T_STRING:
152       loopEchoString();
153       break;
154     default:
155       cerr << "Unexpected loop type" << _loopType << endl;
156       break;
157     }
158 
159     _endTime = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now().time_since_epoch()).count();
160 
161     _transport->close();
162 
163     _done = true;
164 
165     {
166       Synchronized s(_monitor);
167 
168       _workerCount--;
169 
170       if (_workerCount == 0) {
171 
172         _monitor.notify();
173       }
174     }
175   }
176 
loopEchoVoid()177   void loopEchoVoid() {
178     for (size_t ix = 0; ix < _loopCount; ix++) {
179       _client->echoVoid();
180     }
181   }
182 
loopEchoByte()183   void loopEchoByte() {
184     for (size_t ix = 0; ix < _loopCount; ix++) {
185       int8_t arg = 1;
186       int8_t result;
187       result = _client->echoByte(arg);
188       (void)result;
189       assert(result == arg);
190     }
191   }
192 
loopEchoI32()193   void loopEchoI32() {
194     for (size_t ix = 0; ix < _loopCount; ix++) {
195       int32_t arg = 1;
196       int32_t result;
197       result = _client->echoI32(arg);
198       (void)result;
199       assert(result == arg);
200     }
201   }
202 
loopEchoI64()203   void loopEchoI64() {
204     for (size_t ix = 0; ix < _loopCount; ix++) {
205       int64_t arg = 1;
206       int64_t result;
207       result = _client->echoI64(arg);
208       (void)result;
209       assert(result == arg);
210     }
211   }
212 
loopEchoString()213   void loopEchoString() {
214     for (size_t ix = 0; ix < _loopCount; ix++) {
215       string arg = "hello";
216       string result;
217       _client->echoString(result, arg);
218       assert(result == arg);
219     }
220   }
221 
222   std::shared_ptr<TTransport> _transport;
223   std::shared_ptr<ServiceClient> _client;
224   Monitor& _monitor;
225   size_t& _workerCount;
226   size_t _loopCount;
227   TType _loopType;
228   int64_t _startTime;
229   int64_t _endTime;
230   bool _done;
231   Monitor _sleep;
232 };
233 
main(int argc,char ** argv)234 int main(int argc, char** argv) {
235 #if _WIN32
236   transport::TWinsockSingleton::create();
237 #endif
238 
239   int port = 9091;
240   string serverType = "simple";
241   string protocolType = "binary";
242   uint32_t workerCount = 4;
243   uint32_t clientCount = 20;
244   uint32_t loopCount = 1000;
245   TType loopType = T_VOID;
246   string callName = "echoVoid";
247   bool runServer = true;
248   bool logRequests = false;
249   string requestLogPath = "./requestlog.tlog";
250   bool replayRequests = false;
251 
252   ostringstream usage;
253 
254   usage << argv[0] << " [--port=<port number>] [--server] [--server-type=<server-type>] "
255                       "[--protocol-type=<protocol-type>] [--workers=<worker-count>] "
256                       "[--clients=<client-count>] [--loop=<loop-count>]" << endl
257         << "\tclients        Number of client threads to create - 0 implies no clients, i.e. "
258            "server only.  Default is " << clientCount << endl
259         << "\thelp           Prints this help text." << endl
260         << "\tcall           Service method to call.  Default is " << callName << endl
261         << "\tloop           The number of remote thrift calls each client makes.  Default is "
262         << loopCount << endl << "\tport           The port the server and clients should bind to "
263                                 "for thrift network connections.  Default is " << port << endl
264         << "\tserver         Run the Thrift server in this process.  Default is " << runServer
265         << endl << "\tserver-type    Type of server, \"simple\" or \"thread-pool\".  Default is "
266         << serverType << endl
267         << "\tprotocol-type  Type of protocol, \"binary\", \"ascii\", or \"xml\".  Default is "
268         << protocolType << endl
269         << "\tlog-request    Log all request to ./requestlog.tlog. Default is " << logRequests
270         << endl << "\treplay-request Replay requests from log file (./requestlog.tlog) Default is "
271         << replayRequests << endl << "\tworkers        Number of thread pools workers.  Only valid "
272                                      "for thread-pool server type.  Default is " << workerCount
273         << endl;
274 
275   map<string, string> args;
276 
277   for (int ix = 1; ix < argc; ix++) {
278 
279     string arg(argv[ix]);
280 
281     if (arg.compare(0, 2, "--") == 0) {
282 
283       size_t end = arg.find_first_of("=", 2);
284 
285       string key = string(arg, 2, end - 2);
286 
287       if (end != string::npos) {
288         args[key] = string(arg, end + 1);
289       } else {
290         args[key] = "true";
291       }
292     } else {
293       throw invalid_argument("Unexcepted command line token: " + arg);
294     }
295   }
296 
297   try {
298 
299     if (!args["clients"].empty()) {
300       clientCount = atoi(args["clients"].c_str());
301     }
302 
303     if (!args["help"].empty()) {
304       cerr << usage.str();
305       return 0;
306     }
307 
308     if (!args["loop"].empty()) {
309       loopCount = atoi(args["loop"].c_str());
310     }
311 
312     if (!args["call"].empty()) {
313       callName = args["call"];
314     }
315 
316     if (!args["port"].empty()) {
317       port = atoi(args["port"].c_str());
318     }
319 
320     if (!args["server"].empty()) {
321       runServer = args["server"] == "true";
322     }
323 
324     if (!args["log-request"].empty()) {
325       logRequests = args["log-request"] == "true";
326     }
327 
328     if (!args["replay-request"].empty()) {
329       replayRequests = args["replay-request"] == "true";
330     }
331 
332     if (!args["server-type"].empty()) {
333       serverType = args["server-type"];
334     }
335 
336     if (!args["workers"].empty()) {
337       workerCount = atoi(args["workers"].c_str());
338     }
339 
340   } catch (std::exception& e) {
341     cerr << e.what() << endl;
342     cerr << usage.str();
343   }
344 
345   std::shared_ptr<ThreadFactory> threadFactory
346       = std::shared_ptr<ThreadFactory>(new ThreadFactory());
347 
348   // Dispatcher
349   std::shared_ptr<Server> serviceHandler(new Server());
350 
351   if (replayRequests) {
352     std::shared_ptr<Server> serviceHandler(new Server());
353     std::shared_ptr<ServiceProcessor> serviceProcessor(new ServiceProcessor(serviceHandler));
354 
355     // Transports
356     std::shared_ptr<TFileTransport> fileTransport(new TFileTransport(requestLogPath));
357     fileTransport->setChunkSize(2 * 1024 * 1024);
358     fileTransport->setMaxEventSize(1024 * 16);
359     fileTransport->seekToEnd();
360 
361     // Protocol Factory
362     std::shared_ptr<TProtocolFactory> protocolFactory(new TBinaryProtocolFactory());
363 
364     TFileProcessor fileProcessor(serviceProcessor, protocolFactory, fileTransport);
365 
366     fileProcessor.process(0, true);
367     exit(0);
368   }
369 
370   if (runServer) {
371 
372     std::shared_ptr<ServiceProcessor> serviceProcessor(new ServiceProcessor(serviceHandler));
373 
374     // Protocol Factory
375     std::shared_ptr<TProtocolFactory> protocolFactory(new TBinaryProtocolFactory());
376 
377     // Transport Factory
378     std::shared_ptr<TTransportFactory> transportFactory;
379 
380     if (logRequests) {
381       // initialize the log file
382       std::shared_ptr<TFileTransport> fileTransport(new TFileTransport(requestLogPath));
383       fileTransport->setChunkSize(2 * 1024 * 1024);
384       fileTransport->setMaxEventSize(1024 * 16);
385 
386       transportFactory
387           = std::shared_ptr<TTransportFactory>(new TPipedTransportFactory(fileTransport));
388     }
389 
390     std::shared_ptr<Thread> serverThread;
391     std::shared_ptr<Thread> serverThread2;
392     std::shared_ptr<transport::TNonblockingServerSocket> nbSocket1;
393     std::shared_ptr<transport::TNonblockingServerSocket> nbSocket2;
394 
395     if (serverType == "simple") {
396 
397       nbSocket1.reset(new transport::TNonblockingServerSocket(port));
398       serverThread = threadFactory->newThread(std::shared_ptr<TServer>(
399           new TNonblockingServer(serviceProcessor, protocolFactory, nbSocket1)));
400       nbSocket2.reset(new transport::TNonblockingServerSocket(port + 1));
401       serverThread2 = threadFactory->newThread(std::shared_ptr<TServer>(
402           new TNonblockingServer(serviceProcessor, protocolFactory, nbSocket2)));
403 
404     } else if (serverType == "thread-pool") {
405 
406       std::shared_ptr<ThreadManager> threadManager
407           = ThreadManager::newSimpleThreadManager(workerCount);
408 
409       threadManager->threadFactory(threadFactory);
410       threadManager->start();
411       nbSocket1.reset(new transport::TNonblockingServerSocket(port));
412       serverThread = threadFactory->newThread(std::shared_ptr<TServer>(
413           new TNonblockingServer(serviceProcessor, protocolFactory, nbSocket1, threadManager)));
414       nbSocket2.reset(new transport::TNonblockingServerSocket(port + 1));
415       serverThread2 = threadFactory->newThread(std::shared_ptr<TServer>(
416           new TNonblockingServer(serviceProcessor, protocolFactory, nbSocket2, threadManager)));
417     }
418 
419     cerr << "Starting the server on port " << port << " and " << (port + 1) << endl;
420     serverThread->start();
421     serverThread2->start();
422 
423     // If we aren't running clients, just wait forever for external clients
424 
425     if (clientCount == 0) {
426       serverThread->join();
427       serverThread2->join();
428     }
429   }
430   THRIFT_SLEEP_SEC(1);
431 
432   if (clientCount > 0) {
433 
434     Monitor monitor;
435 
436     size_t threadCount = 0;
437 
438     set<std::shared_ptr<Thread> > clientThreads;
439 
440     if (callName == "echoVoid") {
441       loopType = T_VOID;
442     } else if (callName == "echoByte") {
443       loopType = T_BYTE;
444     } else if (callName == "echoI32") {
445       loopType = T_I32;
446     } else if (callName == "echoI64") {
447       loopType = T_I64;
448     } else if (callName == "echoString") {
449       loopType = T_STRING;
450     } else {
451       throw invalid_argument("Unknown service call " + callName);
452     }
453 
454     for (uint32_t ix = 0; ix < clientCount; ix++) {
455 
456       std::shared_ptr<TSocket> socket(new TSocket("127.0.0.1", port + (ix % 2)));
457       std::shared_ptr<TFramedTransport> framedSocket(new TFramedTransport(socket));
458       std::shared_ptr<TProtocol> protocol(new TBinaryProtocol(framedSocket));
459       std::shared_ptr<ServiceClient> serviceClient(new ServiceClient(protocol));
460 
461       clientThreads.insert(threadFactory->newThread(std::shared_ptr<ClientThread>(
462           new ClientThread(socket, serviceClient, monitor, threadCount, loopCount, loopType))));
463     }
464 
465     for (auto thread = clientThreads.begin();
466          thread != clientThreads.end();
467          thread++) {
468       (*thread)->start();
469     }
470 
471     int64_t time00;
472     int64_t time01;
473 
474     {
475       Synchronized s(monitor);
476       threadCount = clientCount;
477 
478       cerr << "Launch " << clientCount << " client threads" << endl;
479 
480       time00 = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now().time_since_epoch()).count();
481 
482       monitor.notifyAll();
483 
484       while (threadCount > 0) {
485         monitor.wait();
486       }
487 
488       time01 = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now().time_since_epoch()).count();
489     }
490 
491     int64_t firstTime = 9223372036854775807LL;
492     int64_t lastTime = 0;
493 
494     double averageTime = 0;
495     int64_t minTime = 9223372036854775807LL;
496     int64_t maxTime = 0;
497 
498     for (auto ix = clientThreads.begin();
499          ix != clientThreads.end();
500          ix++) {
501 
502       std::shared_ptr<ClientThread> client
503           = std::dynamic_pointer_cast<ClientThread>((*ix)->runnable());
504 
505       int64_t delta = client->_endTime - client->_startTime;
506 
507       assert(delta > 0);
508 
509       if (client->_startTime < firstTime) {
510         firstTime = client->_startTime;
511       }
512 
513       if (client->_endTime > lastTime) {
514         lastTime = client->_endTime;
515       }
516 
517       if (delta < minTime) {
518         minTime = delta;
519       }
520 
521       if (delta > maxTime) {
522         maxTime = delta;
523       }
524 
525       averageTime += delta;
526     }
527 
528     averageTime /= clientCount;
529 
530     cout << "workers :" << workerCount << ", client : " << clientCount << ", loops : " << loopCount
531          << ", rate : " << (clientCount * loopCount * 1000) / ((double)(time01 - time00)) << endl;
532 
533     count_map count = serviceHandler->getCount();
534     count_map::iterator iter;
535     for (iter = count.begin(); iter != count.end(); ++iter) {
536       printf("%s => %d\n", iter->first, iter->second);
537     }
538     cerr << "done." << endl;
539   }
540 
541   return 0;
542 }
543