1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 module thrift_test_server;
21 
22 import core.stdc.errno : errno;
23 import core.stdc.signal : signal, SIGINT, SIG_DFL, SIG_ERR;
24 import core.thread : dur, Thread;
25 import std.algorithm;
26 import std.exception : enforce;
27 import std.getopt;
28 import std.parallelism : totalCPUs;
29 import std.string;
30 import std.stdio;
31 import std.typetuple : TypeTuple, staticMap;
32 import thrift.base;
33 import thrift.codegen.processor;
34 import thrift.protocol.base;
35 import thrift.protocol.binary;
36 import thrift.protocol.compact;
37 import thrift.protocol.json;
38 import thrift.server.base;
39 import thrift.server.transport.socket;
40 import thrift.server.transport.ssl;
41 import thrift.transport.base;
42 import thrift.transport.buffered;
43 import thrift.transport.framed;
44 import thrift.transport.http;
45 import thrift.transport.zlib;
46 import thrift.transport.ssl;
47 import thrift.util.cancellation;
48 import thrift.util.hashset;
49 import test_utils;
50 
51 import thrift_test_common;
52 import thrift.test.ThriftTest_types;
53 import thrift.test.ThriftTest;
54 
/**
 * Server-side implementation of the ThriftTest service.
 *
 * Most methods hand their argument straight back; the rest build the fixed
 * responses (or throw the exceptions) that the ThriftTest clients expect.
 * When tracing is enabled, each call is logged to stdout.
 */
class TestHandler : ThriftTest {
  /// Params: trace = whether each incoming call is printed to stdout.
  this(bool trace) {
    trace_ = trace;
  }

  override void testVoid() {
    traceCall("testVoid()");
  }

  override string testString(string thing) {
    traceCall("testString(\"%s\")", thing);
    return thing;
  }

  override byte testByte(byte thing) {
    traceCall("testByte(%s)", thing);
    return thing;
  }

  override int testI32(int thing) {
    traceCall("testI32(%s)", thing);
    return thing;
  }

  override long testI64(long thing) {
    traceCall("testI64(%s)", thing);
    return thing;
  }

  override double testDouble(double thing) {
    traceCall("testDouble(%s)", thing);
    return thing;
  }

  override string testBinary(string thing) {
    traceCall("testBinary(\"%s\")", thing);
    return thing;
  }

  override bool testBool(bool thing) {
    traceCall("testBool(\"%s\")", thing);
    return thing;
  }

  override Xtruct testStruct(ref const(Xtruct) thing) {
    traceCall("testStruct({\"%s\", %s, %s, %s})", thing.string_thing,
      thing.byte_thing, thing.i32_thing, thing.i64_thing);
    return thing;
  }

  override Xtruct2 testNest(ref const(Xtruct2) nest) {
    traceCall("testNest({%s, {\"%s\", %s, %s, %s}, %s})", nest.byte_thing,
      nest.struct_thing.string_thing, nest.struct_thing.byte_thing,
      nest.struct_thing.i32_thing, nest.struct_thing.i64_thing,
      nest.i32_thing);
    return nest;
  }

  override int[int] testMap(int[int] thing) {
    traceCall("testMap({%s})", thing);
    return thing;
  }

  override HashSet!int testSet(HashSet!int thing) {
    // Guarded explicitly so the element rendering only happens when tracing.
    if (trace_) {
      auto rendered = map!`to!string(a)`(thing[]);
      writefln("testSet({%s})", join(rendered, ", "));
    }
    return thing;
  }

  override int[] testList(int[] thing) {
    traceCall("testList(%s)", thing);
    return thing;
  }

  override Numberz testEnum(Numberz thing) {
    traceCall("testEnum(%s)", thing);
    return thing;
  }

  override UserId testTypedef(UserId thing) {
    traceCall("testTypedef(%s)", thing);
    return thing;
  }

  override string[string] testStringMap(string[string] thing) {
    traceCall("testStringMap(%s)", thing);
    return thing;
  }

  override int[int][int] testMapMap(int hello) {
    traceCall("testMapMap(%s)", hello);
    return testMapMapReturn;
  }

  override Insanity[Numberz][UserId] testInsanity(ref const(Insanity) argument) {
    traceCall("testInsanity()");

    // User 1 gets the argument under TWO and THREE; user 2 gets an empty
    // Insanity under SIX.
    auto echoed = cast(Insanity)argument;

    Insanity[Numberz] firstUser;
    firstUser[Numberz.TWO] = echoed;
    firstUser[Numberz.THREE] = echoed;

    Insanity[Numberz] secondUser;
    secondUser[Numberz.SIX] = Insanity();

    Insanity[Numberz][UserId] result;
    result[1] = firstUser;
    result[2] = secondUser;
    return result;
  }

  override Xtruct testMulti(byte arg0, int arg1, long arg2, string[short] arg3,
    Numberz arg4, UserId arg5)
  {
    traceCall("testMulti()");
    return Xtruct("Hello2", arg0, arg1, arg2);
  }

  override void testException(string arg) {
    traceCall("testException(%s)", arg);
    switch (arg) {
      case "Xception": {
        auto xception = new Xception();
        xception.errorCode = 1001;
        xception.message = arg;
        throw xception;
      }
      case "TException":
      case "ApplicationException":
        throw new TException();
      default:
        // Any other argument returns normally.
        break;
    }
  }

  override Xtruct testMultiException(string arg0, string arg1) {
    traceCall("testMultiException(%s, %s)", arg0, arg1);

    if (arg0 == "Xception") {
      auto first = new Xception();
      first.errorCode = 1001;
      first.message = "This is an Xception";
      throw first;
    }

    if (arg0 == "Xception2") {
      auto second = new Xception2();
      second.errorCode = 2002;
      second.struct_thing.string_thing = "This is an Xception2";
      throw second;
    }

    return Xtruct(arg1);
  }

  override void testOneway(int sleepFor) {
    traceCall("testOneway(%s): Sleeping...", sleepFor);
    Thread.sleep(dur!"seconds"(sleepFor));
    traceCall("testOneway(%s): done sleeping!", sleepFor);
  }

private:
  /// Writes `fmt` formatted with `args` (plus a newline) to stdout, but only
  /// when tracing is enabled.
  void traceCall(Args...)(string fmt, Args args) {
    if (trace_) writefln(fmt, args);
  }

  bool trace_;
}
212 
// Shutdown request flag: set to true by the SIGINT handler below and polled
// by ShutdownThread.
shared(bool) gShutdown = false;

// SIGINT handler.  It only records the shutdown request with an atomic store.
// Very little is safe to do inside a signal handler, so it must stay
// nothrow/@nogc and free of allocation or locking.
nothrow @nogc extern(C) void handleSignal(int sig) {
  import core.atomic : atomicStore;

  atomicStore(gShutdown, true);
}
218 
// Watcher thread that waits for a shutdown request (the gShutdown flag, which
// handleSignal sets) and then triggers the given cancellation, causing the
// server to stop.  A signalfd would avoid polling, but it is Linux-specific.
// For portability, the flag is instead polled every 25 ms.
class ShutdownThread : Thread {
  /// Params: cancellation = origin to trigger once shutdown is requested.
  this(TCancellationOrigin cancellation) {
    cancellation_ = cancellation;
    super(&run);
  }

private:
  // Thread body: poll the shared flag until it is set, then cancel.
  void run() {
    import core.atomic : atomicLoad;

    // Read the shared flag atomically rather than with a plain shared access.
    while (!atomicLoad(gShutdown)) {
      Thread.sleep(dur!("msecs")(25));
    }
    cancellation_.trigger();
  }

  TCancellationOrigin cancellation_;
}
242 
/**
 * Entry point of the D ThriftTest server.
 *
 * Parses the command line and builds the protocol, transport and server
 * stack it selects. It then serves ThriftTest requests until SIGINT is
 * received.
 */
void main(string[] args) {
  import core.atomic : atomicStore;

  ushort port = 9090;
  ServerType serverType;
  ProtocolType protocolType;
  size_t numIOThreads = 1;
  TransportType transportType;
  bool ssl = false;
  // NOTE(review): --zlib is parsed below but not used anywhere else in this
  // function, so it currently has no effect. Confirm whether zlib wrapping
  // is intended.
  bool zlib = false;
  bool trace = true;
  size_t taskPoolSize = totalCPUs;

  getopt(args, "port", &port, "protocol", &protocolType, "server-type",
    &serverType, "ssl", &ssl, "zlib", &zlib, "num-io-threads", &numIOThreads,
    "task-pool-size", &taskPoolSize, "trace", &trace,
    "transport", &transportType);

  if (serverType == ServerType.nonblocking ||
    serverType == ServerType.pooledNonblocking
  ) {
    enforce(transportType == TransportType.framed,
      "Need to use framed transport with non-blocking server.");
    enforce(!ssl, "The non-blocking server does not support SSL yet.");

    // Don't wrap the contents into another layer of framing.
    transportType = TransportType.raw;
  }

  version (ThriftTestTemplates) {
    // Only exercise the specialized template code paths if explicitly enabled
    // to reduce memory consumption on regular test suite runs – there should
    // not be much that can go wrong with that specifically anyway.
    alias TypeTuple!(TBufferedTransport, TFramedTransport, TServerHttpTransport)
      AvailableTransports;
    alias TypeTuple!(
      staticMap!(TBinaryProtocol, AvailableTransports),
      staticMap!(TCompactProtocol, AvailableTransports)
    ) AvailableProtocols;
  } else {
    alias TypeTuple!() AvailableTransports;
    alias TypeTuple!() AvailableProtocols;
  }

  TProtocolFactory protocolFactory;
  final switch (protocolType) {
    case ProtocolType.binary:
      protocolFactory = new TBinaryProtocolFactory!AvailableTransports;
      break;
    case ProtocolType.compact:
      protocolFactory = new TCompactProtocolFactory!AvailableTransports;
      break;
    case ProtocolType.json:
      protocolFactory = new TJsonProtocolFactory!AvailableTransports;
      break;
  }

  auto processor = new TServiceProcessor!(ThriftTest, AvailableProtocols)(
    new TestHandler(trace));

  TServerSocket serverSocket;
  if (ssl) {
    auto sslContext = new TSSLContext();
    sslContext.serverSide = true;
    sslContext.loadCertificate("../../../test/keys/server.crt");
    sslContext.loadPrivateKey("../../../test/keys/server.key");
    sslContext.ciphers = "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH";
    serverSocket = new TSSLServerSocket(port, sslContext);
  } else {
    serverSocket = new TServerSocket(port);
  }

  auto transportFactory = createTransportFactory(transportType);

  auto server = createServer(serverType, numIOThreads, taskPoolSize,
    processor, serverSocket, transportFactory, protocolFactory);

  // Set up SIGINT signal handling.  D's format uses %s placeholders (not
  // {0}), so errno actually ends up in the message.
  enforce(signal(SIGINT, &handleSignal) != SIG_ERR,
    format("Could not replace the SIGINT signal handler: errno %s", errno()));

  // Set up a server cancellation trigger
  auto cancel = new TCancellationOrigin();

  // Set up a listener for the shutdown condition - this will
  // wake up when the signal occurs and trigger cancellation.
  auto shutdown = new ShutdownThread(cancel);
  shutdown.start();

  // Serve from this thread; the signal will stop the server
  // and control will return here
  writefln("Starting %s/%s %s ThriftTest server %son port %s...", protocolType,
    transportType, serverType, ssl ? "(using SSL) ": "", port);
  try {
    server.serve(cancel);
  } finally {
    // Make sure the watcher thread ends even if serve() exits with an
    // exception. Otherwise it keeps polling, and the runtime waits for this
    // non-daemon thread at program exit.
    atomicStore(gShutdown, true);
    shutdown.join();
    signal(SIGINT, SIG_DFL);
  }

  writeln("done.");
}
340