1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #define BOOST_TEST_MODULE TNonblockingSSLServerTest
21 #include <boost/test/unit_test.hpp>
22 #include <boost/filesystem.hpp>
23 #include <boost/format.hpp>
24 
25 #include "thrift/server/TNonblockingServer.h"
26 #include "thrift/transport/TSSLSocket.h"
27 #include "thrift/transport/TNonblockingSSLServerSocket.h"
28 
29 #include "gen-cpp/ParentService.h"
30 
31 #include <event.h>
32 #ifdef HAVE_SIGNAL_H
33 #include <signal.h>
34 #endif
35 
36 using namespace apache::thrift;
37 using apache::thrift::concurrency::Guard;
38 using apache::thrift::concurrency::Monitor;
39 using apache::thrift::concurrency::Mutex;
40 using apache::thrift::server::TServerEventHandler;
41 using apache::thrift::transport::TSSLSocketFactory;
42 using apache::thrift::transport::TSSLSocket;
43 
44 struct Handler : public test::ParentServiceIf {
addStringHandler45   void addString(const std::string& s) override { strings_.push_back(s); }
getStringsHandler46   void getStrings(std::vector<std::string>& _return) override { _return = strings_; }
47   std::vector<std::string> strings_;
48 
49   // dummy overrides not used in this test
incrementGenerationHandler50   int32_t incrementGeneration() override { return 0; }
getGenerationHandler51   int32_t getGeneration() override { return 0; }
getDataWaitHandler52   void getDataWait(std::string&, const int32_t) override {}
onewayWaitHandler53   void onewayWait() override {}
exceptionWaitHandler54   void exceptionWait(const std::string&) override {}
unexpectedExceptionWaitHandler55   void unexpectedExceptionWait(const std::string&) override {}
56 };
57 
58 boost::filesystem::path keyDir;
certFile(const std::string & filename)59 boost::filesystem::path certFile(const std::string& filename)
60 {
61   return keyDir / filename;
62 }
63 
64 struct GlobalFixtureSSL
65 {
GlobalFixtureSSLGlobalFixtureSSL66     GlobalFixtureSSL()
67     {
68       using namespace boost::unit_test::framework;
69       for (int i = 0; i < master_test_suite().argc; ++i)
70       {
71         BOOST_TEST_MESSAGE(boost::format("argv[%1%] = \"%2%\"") % i % master_test_suite().argv[i]);
72       }
73 
74 #ifdef __linux__
75       // OpenSSL calls send() without MSG_NOSIGPIPE so writing to a socket that has
76       // disconnected can cause a SIGPIPE signal...
77       signal(SIGPIPE, SIG_IGN);
78 #endif
79 
80       TSSLSocketFactory::setManualOpenSSLInitialization(true);
81       apache::thrift::transport::initializeOpenSSL();
82 
83       keyDir = boost::filesystem::current_path().parent_path().parent_path().parent_path() / "test" / "keys";
84       if (!boost::filesystem::exists(certFile("server.crt")))
85       {
86         keyDir = boost::filesystem::path(master_test_suite().argv[master_test_suite().argc - 1]);
87         if (!boost::filesystem::exists(certFile("server.crt")))
88         {
89           throw std::invalid_argument("The last argument to this test must be the directory containing the test certificate(s).");
90         }
91       }
92     }
93 
~GlobalFixtureSSLGlobalFixtureSSL94     virtual ~GlobalFixtureSSL()
95     {
96       apache::thrift::transport::cleanupOpenSSL();
97 #ifdef __linux__
98       signal(SIGPIPE, SIG_DFL);
99 #endif
100     }
101 };
102 
103 #if (BOOST_VERSION >= 105900)
104 BOOST_GLOBAL_FIXTURE(GlobalFixtureSSL);
105 #else
BOOST_GLOBAL_FIXTURE(GlobalFixtureSSL)106 BOOST_GLOBAL_FIXTURE(GlobalFixtureSSL)
107 #endif
108 
109 std::shared_ptr<TSSLSocketFactory> createServerSocketFactory() {
110   std::shared_ptr<TSSLSocketFactory> pServerSocketFactory;
111 
112   pServerSocketFactory.reset(new TSSLSocketFactory());
113   pServerSocketFactory->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
114   pServerSocketFactory->loadCertificate(certFile("server.crt").string().c_str());
115   pServerSocketFactory->loadPrivateKey(certFile("server.key").string().c_str());
116   pServerSocketFactory->server(true);
117   return pServerSocketFactory;
118 }
119 
createClientSocketFactory()120 std::shared_ptr<TSSLSocketFactory> createClientSocketFactory() {
121   std::shared_ptr<TSSLSocketFactory> pClientSocketFactory;
122 
123   pClientSocketFactory.reset(new TSSLSocketFactory());
124   pClientSocketFactory->authenticate(true);
125   pClientSocketFactory->loadCertificate(certFile("client.crt").string().c_str());
126   pClientSocketFactory->loadPrivateKey(certFile("client.key").string().c_str());
127   pClientSocketFactory->loadTrustedCertificates(certFile("CA.pem").string().c_str());
128   return pClientSocketFactory;
129 }
130 
131 class Fixture {
132 private:
133   struct ListenEventHandler : public TServerEventHandler {
134     public:
ListenEventHandlerFixture::ListenEventHandler135       ListenEventHandler(Mutex* mutex) : listenMonitor_(mutex), ready_(false) {}
136 
preServeFixture::ListenEventHandler137       void preServe() override /* override */ {
138         Guard g(listenMonitor_.mutex());
139         ready_ = true;
140         listenMonitor_.notify();
141       }
142 
143       Monitor listenMonitor_;
144       bool ready_;
145   };
146 
147   struct Runner : public apache::thrift::concurrency::Runnable {
148     int port;
149     std::shared_ptr<event_base> userEventBase;
150     std::shared_ptr<TProcessor> processor;
151     std::shared_ptr<server::TNonblockingServer> server;
152     std::shared_ptr<ListenEventHandler> listenHandler;
153     std::shared_ptr<TSSLSocketFactory> pServerSocketFactory;
154     std::shared_ptr<transport::TNonblockingSSLServerSocket> socket;
155     Mutex mutex_;
156 
RunnerFixture::Runner157     Runner():port(0) {
158       listenHandler.reset(new ListenEventHandler(&mutex_));
159     }
160 
runFixture::Runner161     void run() override {
162       // When binding to explicit port, allow retrying to workaround bind failures on ports in use
163       int retryCount = port ? 10 : 0;
164       pServerSocketFactory = createServerSocketFactory();
165       startServer(retryCount);
166     }
167 
readyBarrierFixture::Runner168     void readyBarrier() {
169       // block until server is listening and ready to accept connections
170       Guard g(mutex_);
171       while (!listenHandler->ready_) {
172         listenHandler->listenMonitor_.wait();
173       }
174     }
175   private:
startServerFixture::Runner176     void startServer(int retry_count) {
177       try {
178         socket.reset(new transport::TNonblockingSSLServerSocket(port, pServerSocketFactory));
179         server.reset(new server::TNonblockingServer(processor, socket));
180 	      server->setServerEventHandler(listenHandler);
181         server->setNumIOThreads(1);
182         if (userEventBase) {
183           server->registerEvents(userEventBase.get());
184         }
185         server->serve();
186       } catch (const transport::TTransportException&) {
187         if (retry_count > 0) {
188           ++port;
189           startServer(retry_count - 1);
190         } else {
191           throw;
192         }
193       }
194     }
195   };
196 
197   struct EventDeleter {
operator ()Fixture::EventDeleter198     void operator()(event_base* p) { event_base_free(p); }
199   };
200 
201 protected:
Fixture()202   Fixture() : processor(new test::ParentServiceProcessor(std::make_shared<Handler>())) {}
203 
~Fixture()204   ~Fixture() {
205     if (server) {
206       server->stop();
207     }
208     if (thread) {
209       thread->join();
210     }
211   }
212 
setEventBase(event_base * user_event_base)213   void setEventBase(event_base* user_event_base) {
214     userEventBase_.reset(user_event_base, EventDeleter());
215   }
216 
startServer(int port)217   int startServer(int port) {
218     std::shared_ptr<Runner> runner(new Runner);
219     runner->port = port;
220     runner->processor = processor;
221     runner->userEventBase = userEventBase_;
222 
223     std::unique_ptr<apache::thrift::concurrency::ThreadFactory> threadFactory(
224         new apache::thrift::concurrency::ThreadFactory(false));
225     thread = threadFactory->newThread(runner);
226     thread->start();
227     runner->readyBarrier();
228 
229     server = runner->server;
230     return runner->port;
231   }
232 
canCommunicate(int serverPort)233   bool canCommunicate(int serverPort) {
234     std::shared_ptr<TSSLSocketFactory> pClientSocketFactory = createClientSocketFactory();
235     std::shared_ptr<TSSLSocket> socket = pClientSocketFactory->createSocket("localhost", serverPort);
236     socket->open();
237     test::ParentServiceClient client(std::make_shared<protocol::TBinaryProtocol>(
238         std::make_shared<transport::TFramedTransport>(socket)));
239     client.addString("foo");
240     std::vector<std::string> strings;
241     client.getStrings(strings);
242     return strings.size() == 1 && !(strings[0].compare("foo"));
243   }
244 
245 private:
246   std::shared_ptr<event_base> userEventBase_;
247   std::shared_ptr<test::ParentServiceProcessor> processor;
248 protected:
249   std::shared_ptr<server::TNonblockingServer> server;
250 private:
251   std::shared_ptr<apache::thrift::concurrency::Thread> thread;
252 
253 };
254 
255 BOOST_AUTO_TEST_SUITE(TNonblockingSSLServerTest)
256 
// Starting the server on an explicit port: the port reported back must be the
// requested one or a higher one, must match the server's listen port, and the
// server must answer a client.
BOOST_FIXTURE_TEST_CASE(get_specified_port, Fixture) {
  const int specified_port = startServer(12345);

  BOOST_REQUIRE_GE(specified_port, 12345);
  BOOST_REQUIRE_EQUAL(server->getListenPort(), specified_port);
  BOOST_CHECK(canCommunicate(specified_port));

  server->stop();
}
265 
// Starting the server on port 0: startServer() still reports 0, while the
// server's listen port must be non-zero and must accept a client.
BOOST_FIXTURE_TEST_CASE(get_assigned_port, Fixture) {
  const int specified_port = startServer(0);
  BOOST_REQUIRE_EQUAL(specified_port, 0);

  const int assigned_port = server->getListenPort();
  BOOST_REQUIRE_NE(assigned_port, 0);
  BOOST_CHECK(canCommunicate(assigned_port));

  server->stop();
}
275 
// Starting the server with a caller-supplied event_base: the server must
// still answer a client and, where the libevent version allows the query,
// must have added events to the supplied base.
BOOST_FIXTURE_TEST_CASE(provide_event_base, Fixture) {
  event_base* eb = event_base_new();
  // event_base_new() can fail and return NULL. Without this check the fixture
  // would hold an empty pointer, the server would fall back to its own event
  // base, and the test would pass without testing a user-provided base.
  BOOST_REQUIRE(eb != nullptr);
  setEventBase(eb);
  startServer(0);

  // assert that the server works
  BOOST_CHECK(canCommunicate(server->getListenPort()));
#if LIBEVENT_VERSION_NUMBER > 0x02010400
  // also assert that the event_base is actually used when it's easy
  BOOST_CHECK_GT(event_base_get_num_events(eb, EVENT_BASE_COUNT_ADDED), 0);
#endif
}
288 
289 BOOST_AUTO_TEST_SUITE_END()
290