/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include "Service.h" #include #include #include #include #include #if _WIN32 #include #endif using namespace std; using namespace apache::thrift; using namespace apache::thrift::async; using namespace apache::thrift::protocol; using namespace apache::thrift::transport; using namespace apache::thrift::server; using namespace apache::thrift::concurrency; using namespace test::stress; struct eqstr { bool operator()(const char* s1, const char* s2) const { return strcmp(s1, s2) == 0; } }; struct ltstr { bool operator()(const char* s1, const char* s2) const { return strcmp(s1, s2) < 0; } }; // typedef hash_map, eqstr> count_map; typedef map count_map; class Server : public ServiceIf { public: Server() = default; void count(const char* method) { Guard m(lock_); int ct = counts_[method]; counts_[method] = ++ct; } void echoVoid() override { count("echoVoid"); return; } count_map getCount() { Guard m(lock_); return counts_; } int8_t echoByte(const int8_t arg) override { return arg; } int32_t echoI32(const int32_t arg) override { return arg; } int64_t echoI64(const int64_t arg) override { return arg; } void echoString(string& out, const string& arg) override { if (arg != "hello") { T_ERROR_ABORT("WRONG STRING (%s)!!!!", arg.c_str()); } out = arg; } void echoList(vector& out, const vector& arg) override { out = arg; } void echoSet(set& out, const set& arg) override { out = arg; } void echoMap(map& out, const map& arg) override { out = arg; } private: count_map counts_; Mutex lock_; }; enum TransportOpenCloseBehavior { OpenAndCloseTransportInThread, DontOpenAndCloseTransportInThread }; class ClientThread : public Runnable { public: ClientThread(std::shared_ptr transport, std::shared_ptr client, Monitor& monitor, size_t& workerCount, size_t loopCount, TType loopType, TransportOpenCloseBehavior behavior) : _transport(transport), _client(client), _monitor(monitor), _workerCount(workerCount), _loopCount(loopCount), _loopType(loopType), _behavior(behavior) {} void run() override { // Wait for all worker threads to start { Synchronized s(_monitor); while (_workerCount == 0) { _monitor.wait(); } } _startTime = std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch()).count(); if(_behavior == OpenAndCloseTransportInThread) { _transport->open(); } switch (_loopType) { case T_VOID: loopEchoVoid(); break; case T_BYTE: loopEchoByte(); break; case T_I32: loopEchoI32(); break; case T_I64: loopEchoI64(); break; case T_STRING: loopEchoString(); break; default: cerr << "Unexpected loop type" << _loopType << endl; break; } _endTime = std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch()).count(); if(_behavior == OpenAndCloseTransportInThread) { _transport->close(); } _done = true; { Synchronized s(_monitor); _workerCount--; if (_workerCount == 0) { _monitor.notify(); } } } void loopEchoVoid() { for (size_t ix = 0; ix < _loopCount; ix++) { _client->echoVoid(); } } void loopEchoByte() { for (size_t ix = 0; ix < _loopCount; ix++) { int8_t arg = 1; int8_t result; result = _client->echoByte(arg); (void)result; assert(result == arg); } } void loopEchoI32() { for (size_t ix = 0; ix < _loopCount; ix++) { int32_t arg = 1; int32_t result; result = _client->echoI32(arg); (void)result; assert(result == arg); } } void loopEchoI64() { for (size_t ix = 0; ix < _loopCount; ix++) { int64_t arg = 1; int64_t result; result = _client->echoI64(arg); (void)result; assert(result == arg); } } void loopEchoString() { for (size_t ix = 0; ix < _loopCount; ix++) { string arg = "hello"; string result; _client->echoString(result, arg); assert(result == arg); } } std::shared_ptr _transport; std::shared_ptr _client; Monitor& _monitor; size_t& _workerCount; size_t _loopCount; TType _loopType; int64_t _startTime; int64_t _endTime; bool _done; Monitor _sleep; TransportOpenCloseBehavior _behavior; }; class TStartObserver : public apache::thrift::server::TServerEventHandler { public: TStartObserver() : awake_(false) {} void preServe() override { apache::thrift::concurrency::Synchronized s(m_); awake_ = true; m_.notifyAll(); } void waitForService() { apache::thrift::concurrency::Synchronized s(m_); while (!awake_) m_.waitForever(); } private: apache::thrift::concurrency::Monitor m_; bool awake_; }; int main(int argc, char** argv) { #if _WIN32 transport::TWinsockSingleton::create(); #endif int port = 9091; string clientType = "regular"; string serverType = "thread-pool"; string protocolType = "binary"; size_t workerCount = 8; size_t clientCount = 4; size_t loopCount = 50000; TType loopType = T_VOID; string callName = "echoVoid"; bool runServer = true; bool logRequests = false; string requestLogPath = "./requestlog.tlog"; bool replayRequests = false; ostringstream usage; usage << argv[0] << " [--port=] [--server] [--server-type=] " "[--protocol-type=] [--workers=] " "[--clients=] [--loop=] " "[--client-type=]" << endl << "\tclients Number of client threads to create - 0 implies no clients, i.e. " "server only. Default is " << clientCount << endl << "\thelp Prints this help text." << endl << "\tcall Service method to call. Default is " << callName << endl << "\tloop The number of remote thrift calls each client makes. Default is " << loopCount << endl << "\tport The port the server and clients should bind to " "for thrift network connections. Default is " << port << endl << "\tserver Run the Thrift server in this process. Default is " << runServer << endl << "\tserver-type Type of server, \"simple\" or \"thread-pool\". Default is " << serverType << endl << "\tprotocol-type Type of protocol, \"binary\", \"ascii\", or \"xml\". Default is " << protocolType << endl << "\tlog-request Log all request to ./requestlog.tlog. Default is " << logRequests << endl << "\treplay-request Replay requests from log file (./requestlog.tlog) Default is " << replayRequests << endl << "\tworkers Number of thread pools workers. Only valid " "for thread-pool server type. Default is " << workerCount << endl << "\tclient-type Type of client, \"regular\" or \"concurrent\". Default is " << clientType << endl << endl; map args; for (int ix = 1; ix < argc; ix++) { string arg(argv[ix]); if (arg.compare(0, 2, "--") == 0) { size_t end = arg.find_first_of("=", 2); string key = string(arg, 2, end - 2); if (end != string::npos) { args[key] = string(arg, end + 1); } else { args[key] = "true"; } } else { throw invalid_argument("Unexcepted command line token: " + arg); } } try { if (!args["clients"].empty()) { clientCount = atoi(args["clients"].c_str()); } if (!args["help"].empty()) { cerr << usage.str(); return 0; } if (!args["loop"].empty()) { loopCount = atoi(args["loop"].c_str()); } if (!args["call"].empty()) { callName = args["call"]; } if (!args["port"].empty()) { port = atoi(args["port"].c_str()); } if (!args["server"].empty()) { runServer = args["server"] == "true"; } if (!args["log-request"].empty()) { logRequests = args["log-request"] == "true"; } if (!args["replay-request"].empty()) { replayRequests = args["replay-request"] == "true"; } if (!args["server-type"].empty()) { serverType = args["server-type"]; if (serverType == "simple") { } else if (serverType == "thread-pool") { } else if (serverType == "threaded") { } else { throw invalid_argument("Unknown server type " + serverType); } } if (!args["client-type"].empty()) { clientType = args["client-type"]; if (clientType == "regular") { } else if (clientType == "concurrent") { } else { throw invalid_argument("Unknown client type " + clientType); } } if (!args["workers"].empty()) { workerCount = atoi(args["workers"].c_str()); } } catch (std::exception& e) { cerr << e.what() << endl; cerr << usage.str(); } std::shared_ptr threadFactory = std::shared_ptr(new ThreadFactory()); // Dispatcher std::shared_ptr serviceHandler(new Server()); if (replayRequests) { std::shared_ptr serviceHandler(new Server()); std::shared_ptr serviceProcessor(new ServiceProcessor(serviceHandler)); // Transports std::shared_ptr fileTransport(new TFileTransport(requestLogPath)); fileTransport->setChunkSize(2 * 1024 * 1024); fileTransport->setMaxEventSize(1024 * 16); fileTransport->seekToEnd(); // Protocol Factory std::shared_ptr protocolFactory(new TBinaryProtocolFactory()); TFileProcessor fileProcessor(serviceProcessor, protocolFactory, fileTransport); fileProcessor.process(0, true); exit(0); } if (runServer) { std::shared_ptr serviceProcessor(new ServiceProcessor(serviceHandler)); // Transport std::shared_ptr serverSocket(new TServerSocket(port)); // Transport Factory std::shared_ptr transportFactory(new TBufferedTransportFactory()); // Protocol Factory std::shared_ptr protocolFactory(new TBinaryProtocolFactory()); if (logRequests) { // initialize the log file std::shared_ptr fileTransport(new TFileTransport(requestLogPath)); fileTransport->setChunkSize(2 * 1024 * 1024); fileTransport->setMaxEventSize(1024 * 16); transportFactory = std::shared_ptr(new TPipedTransportFactory(fileTransport)); } std::shared_ptr server; if (serverType == "simple") { server.reset( new TSimpleServer(serviceProcessor, serverSocket, transportFactory, protocolFactory)); } else if (serverType == "threaded") { server.reset( new TThreadedServer(serviceProcessor, serverSocket, transportFactory, protocolFactory)); } else if (serverType == "thread-pool") { std::shared_ptr threadManager = ThreadManager::newSimpleThreadManager(workerCount); threadManager->threadFactory(threadFactory); threadManager->start(); server.reset(new TThreadPoolServer(serviceProcessor, serverSocket, transportFactory, protocolFactory, threadManager)); } std::shared_ptr observer(new TStartObserver); server->setServerEventHandler(observer); std::shared_ptr serverThread = threadFactory->newThread(server); cerr << "Starting the server on port " << port << endl; serverThread->start(); observer->waitForService(); // If we aren't running clients, just wait forever for external clients if (clientCount == 0) { serverThread->join(); } } if (clientCount > 0) { //FIXME: start here for client type? Monitor monitor; size_t threadCount = 0; set > clientThreads; if (callName == "echoVoid") { loopType = T_VOID; } else if (callName == "echoByte") { loopType = T_BYTE; } else if (callName == "echoI32") { loopType = T_I32; } else if (callName == "echoI64") { loopType = T_I64; } else if (callName == "echoString") { loopType = T_STRING; } else { throw invalid_argument("Unknown service call " + callName); } if(clientType == "regular") { for (size_t ix = 0; ix < clientCount; ix++) { std::shared_ptr socket(new TSocket("127.0.0.1", port)); std::shared_ptr bufferedSocket(new TBufferedTransport(socket, 2048)); std::shared_ptr protocol(new TBinaryProtocol(bufferedSocket)); std::shared_ptr serviceClient(new ServiceClient(protocol)); clientThreads.insert(threadFactory->newThread(std::shared_ptr( new ClientThread(socket, serviceClient, monitor, threadCount, loopCount, loopType, OpenAndCloseTransportInThread)))); } } else if(clientType == "concurrent") { std::shared_ptr socket(new TSocket("127.0.0.1", port)); std::shared_ptr bufferedSocket(new TBufferedTransport(socket, 2048)); std::shared_ptr protocol(new TBinaryProtocol(bufferedSocket)); auto sync = std::make_shared(); std::shared_ptr serviceClient(new ServiceConcurrentClient(protocol, sync)); socket->open(); for (size_t ix = 0; ix < clientCount; ix++) { clientThreads.insert(threadFactory->newThread(std::shared_ptr( new ClientThread(socket, serviceClient, monitor, threadCount, loopCount, loopType, DontOpenAndCloseTransportInThread)))); } } for (auto thread = clientThreads.begin(); thread != clientThreads.end(); thread++) { (*thread)->start(); } int64_t time00; int64_t time01; { Synchronized s(monitor); threadCount = clientCount; cerr << "Launch " << clientCount << " " << clientType << " client threads" << endl; time00 = std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch()).count(); monitor.notifyAll(); while (threadCount > 0) { monitor.wait(); } time01 = std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch()).count(); } int64_t firstTime = 9223372036854775807LL; int64_t lastTime = 0; double averageTime = 0; int64_t minTime = 9223372036854775807LL; int64_t maxTime = 0; for (auto ix = clientThreads.begin(); ix != clientThreads.end(); ix++) { std::shared_ptr client = std::dynamic_pointer_cast((*ix)->runnable()); int64_t delta = client->_endTime - client->_startTime; assert(delta > 0); if (client->_startTime < firstTime) { firstTime = client->_startTime; } if (client->_endTime > lastTime) { lastTime = client->_endTime; } if (delta < minTime) { minTime = delta; } if (delta > maxTime) { maxTime = delta; } averageTime += delta; } averageTime /= clientCount; cout << "workers :" << workerCount << ", client : " << clientCount << ", loops : " << loopCount << ", rate : " << (clientCount * loopCount * 1000) / ((double)(time01 - time00)) << endl; count_map count = serviceHandler->getCount(); count_map::iterator iter; for (iter = count.begin(); iter != count.end(); ++iter) { printf("%s => %d\n", iter->first, iter->second); } cerr << "done." << endl; } return 0; }