1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 #define BOOST_TEST_MODULE TNonblockingSSLServerTest
21 #include <boost/test/unit_test.hpp>
22 #include <boost/filesystem.hpp>
23 #include <boost/format.hpp>
24
25 #include "thrift/server/TNonblockingServer.h"
26 #include "thrift/transport/TSSLSocket.h"
27 #include "thrift/transport/TNonblockingSSLServerSocket.h"
28
29 #include "gen-cpp/ParentService.h"
30
31 #include <event.h>
32 #ifdef HAVE_SIGNAL_H
33 #include <signal.h>
34 #endif
35
36 using namespace apache::thrift;
37 using apache::thrift::concurrency::Guard;
38 using apache::thrift::concurrency::Monitor;
39 using apache::thrift::concurrency::Mutex;
40 using apache::thrift::server::TServerEventHandler;
41 using apache::thrift::transport::TSSLSocketFactory;
42 using apache::thrift::transport::TSSLSocket;
43
44 struct Handler : public test::ParentServiceIf {
addStringHandler45 void addString(const std::string& s) override { strings_.push_back(s); }
getStringsHandler46 void getStrings(std::vector<std::string>& _return) override { _return = strings_; }
47 std::vector<std::string> strings_;
48
49 // dummy overrides not used in this test
incrementGenerationHandler50 int32_t incrementGeneration() override { return 0; }
getGenerationHandler51 int32_t getGeneration() override { return 0; }
getDataWaitHandler52 void getDataWait(std::string&, const int32_t) override {}
onewayWaitHandler53 void onewayWait() override {}
exceptionWaitHandler54 void exceptionWait(const std::string&) override {}
unexpectedExceptionWaitHandler55 void unexpectedExceptionWait(const std::string&) override {}
56 };
57
58 boost::filesystem::path keyDir;
certFile(const std::string & filename)59 boost::filesystem::path certFile(const std::string& filename)
60 {
61 return keyDir / filename;
62 }
63
64 struct GlobalFixtureSSL
65 {
GlobalFixtureSSLGlobalFixtureSSL66 GlobalFixtureSSL()
67 {
68 using namespace boost::unit_test::framework;
69 for (int i = 0; i < master_test_suite().argc; ++i)
70 {
71 BOOST_TEST_MESSAGE(boost::format("argv[%1%] = \"%2%\"") % i % master_test_suite().argv[i]);
72 }
73
74 #ifdef __linux__
75 // OpenSSL calls send() without MSG_NOSIGPIPE so writing to a socket that has
76 // disconnected can cause a SIGPIPE signal...
77 signal(SIGPIPE, SIG_IGN);
78 #endif
79
80 TSSLSocketFactory::setManualOpenSSLInitialization(true);
81 apache::thrift::transport::initializeOpenSSL();
82
83 keyDir = boost::filesystem::current_path().parent_path().parent_path().parent_path() / "test" / "keys";
84 if (!boost::filesystem::exists(certFile("server.crt")))
85 {
86 keyDir = boost::filesystem::path(master_test_suite().argv[master_test_suite().argc - 1]);
87 if (!boost::filesystem::exists(certFile("server.crt")))
88 {
89 throw std::invalid_argument("The last argument to this test must be the directory containing the test certificate(s).");
90 }
91 }
92 }
93
~GlobalFixtureSSLGlobalFixtureSSL94 virtual ~GlobalFixtureSSL()
95 {
96 apache::thrift::transport::cleanupOpenSSL();
97 #ifdef __linux__
98 signal(SIGPIPE, SIG_DFL);
99 #endif
100 }
101 };
102
103 #if (BOOST_VERSION >= 105900)
104 BOOST_GLOBAL_FIXTURE(GlobalFixtureSSL);
105 #else
BOOST_GLOBAL_FIXTURE(GlobalFixtureSSL)106 BOOST_GLOBAL_FIXTURE(GlobalFixtureSSL)
107 #endif
108
109 std::shared_ptr<TSSLSocketFactory> createServerSocketFactory() {
110 std::shared_ptr<TSSLSocketFactory> pServerSocketFactory;
111
112 pServerSocketFactory.reset(new TSSLSocketFactory());
113 pServerSocketFactory->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
114 pServerSocketFactory->loadCertificate(certFile("server.crt").string().c_str());
115 pServerSocketFactory->loadPrivateKey(certFile("server.key").string().c_str());
116 pServerSocketFactory->server(true);
117 return pServerSocketFactory;
118 }
119
createClientSocketFactory()120 std::shared_ptr<TSSLSocketFactory> createClientSocketFactory() {
121 std::shared_ptr<TSSLSocketFactory> pClientSocketFactory;
122
123 pClientSocketFactory.reset(new TSSLSocketFactory());
124 pClientSocketFactory->authenticate(true);
125 pClientSocketFactory->loadCertificate(certFile("client.crt").string().c_str());
126 pClientSocketFactory->loadPrivateKey(certFile("client.key").string().c_str());
127 pClientSocketFactory->loadTrustedCertificates(certFile("CA.pem").string().c_str());
128 return pClientSocketFactory;
129 }
130
131 class Fixture {
132 private:
133 struct ListenEventHandler : public TServerEventHandler {
134 public:
ListenEventHandlerFixture::ListenEventHandler135 ListenEventHandler(Mutex* mutex) : listenMonitor_(mutex), ready_(false) {}
136
preServeFixture::ListenEventHandler137 void preServe() override /* override */ {
138 Guard g(listenMonitor_.mutex());
139 ready_ = true;
140 listenMonitor_.notify();
141 }
142
143 Monitor listenMonitor_;
144 bool ready_;
145 };
146
147 struct Runner : public apache::thrift::concurrency::Runnable {
148 int port;
149 std::shared_ptr<event_base> userEventBase;
150 std::shared_ptr<TProcessor> processor;
151 std::shared_ptr<server::TNonblockingServer> server;
152 std::shared_ptr<ListenEventHandler> listenHandler;
153 std::shared_ptr<TSSLSocketFactory> pServerSocketFactory;
154 std::shared_ptr<transport::TNonblockingSSLServerSocket> socket;
155 Mutex mutex_;
156
RunnerFixture::Runner157 Runner():port(0) {
158 listenHandler.reset(new ListenEventHandler(&mutex_));
159 }
160
runFixture::Runner161 void run() override {
162 // When binding to explicit port, allow retrying to workaround bind failures on ports in use
163 int retryCount = port ? 10 : 0;
164 pServerSocketFactory = createServerSocketFactory();
165 startServer(retryCount);
166 }
167
readyBarrierFixture::Runner168 void readyBarrier() {
169 // block until server is listening and ready to accept connections
170 Guard g(mutex_);
171 while (!listenHandler->ready_) {
172 listenHandler->listenMonitor_.wait();
173 }
174 }
175 private:
startServerFixture::Runner176 void startServer(int retry_count) {
177 try {
178 socket.reset(new transport::TNonblockingSSLServerSocket(port, pServerSocketFactory));
179 server.reset(new server::TNonblockingServer(processor, socket));
180 server->setServerEventHandler(listenHandler);
181 server->setNumIOThreads(1);
182 if (userEventBase) {
183 server->registerEvents(userEventBase.get());
184 }
185 server->serve();
186 } catch (const transport::TTransportException&) {
187 if (retry_count > 0) {
188 ++port;
189 startServer(retry_count - 1);
190 } else {
191 throw;
192 }
193 }
194 }
195 };
196
197 struct EventDeleter {
operator ()Fixture::EventDeleter198 void operator()(event_base* p) { event_base_free(p); }
199 };
200
201 protected:
Fixture()202 Fixture() : processor(new test::ParentServiceProcessor(std::make_shared<Handler>())) {}
203
~Fixture()204 ~Fixture() {
205 if (server) {
206 server->stop();
207 }
208 if (thread) {
209 thread->join();
210 }
211 }
212
setEventBase(event_base * user_event_base)213 void setEventBase(event_base* user_event_base) {
214 userEventBase_.reset(user_event_base, EventDeleter());
215 }
216
startServer(int port)217 int startServer(int port) {
218 std::shared_ptr<Runner> runner(new Runner);
219 runner->port = port;
220 runner->processor = processor;
221 runner->userEventBase = userEventBase_;
222
223 std::unique_ptr<apache::thrift::concurrency::ThreadFactory> threadFactory(
224 new apache::thrift::concurrency::ThreadFactory(false));
225 thread = threadFactory->newThread(runner);
226 thread->start();
227 runner->readyBarrier();
228
229 server = runner->server;
230 return runner->port;
231 }
232
canCommunicate(int serverPort)233 bool canCommunicate(int serverPort) {
234 std::shared_ptr<TSSLSocketFactory> pClientSocketFactory = createClientSocketFactory();
235 std::shared_ptr<TSSLSocket> socket = pClientSocketFactory->createSocket("localhost", serverPort);
236 socket->open();
237 test::ParentServiceClient client(std::make_shared<protocol::TBinaryProtocol>(
238 std::make_shared<transport::TFramedTransport>(socket)));
239 client.addString("foo");
240 std::vector<std::string> strings;
241 client.getStrings(strings);
242 return strings.size() == 1 && !(strings[0].compare("foo"));
243 }
244
245 private:
246 std::shared_ptr<event_base> userEventBase_;
247 std::shared_ptr<test::ParentServiceProcessor> processor;
248 protected:
249 std::shared_ptr<server::TNonblockingServer> server;
250 private:
251 std::shared_ptr<apache::thrift::concurrency::Thread> thread;
252
253 };
254
255 BOOST_AUTO_TEST_SUITE(TNonblockingSSLServerTest)
256
// Starts the server on an explicit port and checks that the port reported by
// startServer() is at least the requested one, matches the server's listen
// port, and that a client can talk to it.
BOOST_FIXTURE_TEST_CASE(get_specified_port, Fixture) {
  const int specified_port = startServer(12345);

  BOOST_REQUIRE_GE(specified_port, 12345);
  BOOST_REQUIRE_EQUAL(server->getListenPort(), specified_port);
  BOOST_CHECK(canCommunicate(specified_port));

  server->stop();
}
265
// Starts the server with port 0 and checks that startServer() still reports 0
// while the server itself listens on a non-zero port a client can reach.
BOOST_FIXTURE_TEST_CASE(get_assigned_port, Fixture) {
  const int specified_port = startServer(0);
  BOOST_REQUIRE_EQUAL(specified_port, 0);

  const int assigned_port = server->getListenPort();
  BOOST_REQUIRE_NE(assigned_port, 0);
  BOOST_CHECK(canCommunicate(assigned_port));

  server->stop();
}
275
// Hands a caller-created event_base to the fixture before starting the server
// and checks that the server is reachable; on sufficiently new libevent it
// also checks that events were added to that event_base.
BOOST_FIXTURE_TEST_CASE(provide_event_base, Fixture) {
  event_base* eb = event_base_new();
  // Without this check a failed allocation would leave the fixture with an
  // empty event base: the server would then run on its own base and the test
  // would no longer exercise the code path it is named after.
  BOOST_REQUIRE(eb != nullptr);
  setEventBase(eb);  // the fixture now owns eb and frees it on teardown
  startServer(0);

  // assert that the server works
  BOOST_CHECK(canCommunicate(server->getListenPort()));
#if LIBEVENT_VERSION_NUMBER > 0x02010400
  // also assert that the event_base is actually used when it's easy
  BOOST_CHECK_GT(event_base_get_num_events(eb, EVENT_BASE_COUNT_ADDED), 0);
#endif
}
288
289 BOOST_AUTO_TEST_SUITE_END()
290