1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 module thrift_test_client;
20 
21 import std.conv;
22 import std.datetime.stopwatch;
23 import std.exception : enforce;
24 import std.getopt;
25 import std.stdio;
26 import std.string;
27 import std.traits;
28 import thrift.base;
29 import thrift.codegen.client;
30 import thrift.protocol.base;
31 import thrift.protocol.binary;
32 import thrift.protocol.compact;
33 import thrift.protocol.json;
34 import thrift.transport.base;
35 import thrift.transport.buffered;
36 import thrift.transport.framed;
37 import thrift.transport.http;
38 import thrift.transport.zlib;
39 import thrift.transport.socket;
40 import thrift.transport.ssl;
41 import thrift.util.hashset;
42 
43 import thrift_test_common;
44 import thrift.test.ThriftTest;
45 import thrift.test.ThriftTest_types;
46 
/**
 * Transport layer that `main` stacks on top of the (optionally SSL) socket.
 *
 * Selected on the command line via `--transport`; because it is declared
 * first, `buffered` is what an uninitialised `TransportType` holds.
 */
enum TransportType {
  buffered = 0, /// TBufferedTransport around the socket.
  framed = 1,   /// TFramedTransport around the socket.
  http = 2,     /// TClientHttpTransport posting to "/service".
  zlib = 3,     /// TZlibTransport around the socket.
  raw = 4,      /// The bare socket with no additional layer.
}
54 
/**
 * Builds the wire protocol requested by `type` on top of `trans`.
 *
 * Params:
 *   trans = transport the protocol will read from and write to.
 *   type  = which protocol to construct (binary, compact or JSON).
 * Returns: a newly created protocol bound to `trans`.
 */
TProtocol createProtocol(T)(T trans, ProtocolType type) {
  TProtocol result;
  with (ProtocolType) final switch (type) {
    case binary:
      result = tBinaryProtocol(trans);
      break;
    case compact:
      result = tCompactProtocol(trans);
      break;
    case json:
      result = tJsonProtocol(trans);
      break;
  }
  return result;
}
65 
/**
 * Entry point of the ThriftTest cross-language test client.
 *
 * Parses the command line, builds the socket -> transport -> protocol stack
 * that was requested, then runs the ThriftTest call sequence `numTests` times
 * against the server. Every response is verified with `enforce`, so any
 * mismatch aborts the run with an exception.
 *
 * Options: --numTests/-n, --protocol, --ssl, --transport, --zlib, --trace,
 * --port and --host (the latter also accepts a trailing `:port`).
 *
 * When more than one iteration is run, the min/max/average wall-clock time
 * per iteration (in microseconds) is printed at the end.
 */
void main(string[] args) {
  string host = "localhost";
  ushort port = 9090;
  uint numTests = 1;
  bool ssl;
  ProtocolType protocolType;
  TransportType transportType;
  bool zlib;
  bool trace;

  getopt(args,
    "numTests|n", &numTests,
    "protocol", &protocolType,
    "ssl", &ssl,
    "transport", &transportType,
    "zlib", &zlib,
    "trace", &trace,
    "port", &port,
    "host", (string _, string value) {
      auto parts = split(value, ":");
      if (parts.length > 1) {
        // IPv6 addresses can contain colons, so take the last part for the
        // port. NOTE(review): a bare IPv6 address without a port would be
        // misparsed here.
        host = join(parts[0 .. $ - 1], ":");
        port = to!ushort(parts[$ - 1]);
      } else {
        host = value;
      }
    }
  );

  TSocket socket;
  if (ssl) {
    auto sslContext = new TSSLContext();
    sslContext.ciphers = "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH";
    sslContext.authenticate = true;
    sslContext.loadTrustedCertificates("../../../test/keys/CA.pem");
    socket = new TSSLSocket(sslContext, host, port);
  } else {
    socket = new TSocket(host, port);
  }

  TTransport transport;
  final switch (transportType) {
    case TransportType.buffered:
      transport = new TBufferedTransport(socket);
      break;
    case TransportType.framed:
      transport = new TFramedTransport(socket);
      break;
    case TransportType.http:
      transport = new TClientHttpTransport(socket, host, "/service");
      break;
    case TransportType.zlib:
      transport = new TZlibTransport(socket);
      break;
    case TransportType.raw:
      transport = socket;
      break;
  }
  // Optionally compress on top of the transport selected above. Wrap the
  // already-built transport rather than the raw socket, otherwise the
  // buffered/framed/http layer would be silently dropped.
  if (zlib && transportType != TransportType.zlib) {
    transport = new TZlibTransport(transport);
  }
  TProtocol protocol = createProtocol(transport, protocolType);

  auto client = tClient!ThriftTest(protocol);

  ulong time_min;
  ulong time_max;
  ulong time_tot;

  StopWatch sw;
  foreach (test; 0 .. numTests) {
    sw.start();

    protocol.transport.open();
    // Release the connection even if one of the checks below throws.
    scope (exit) protocol.transport.close();

    if (trace) writefln("Test #%s, connect %s:%s", test + 1, host, port);

    if (trace) write("testVoid()");
    client.testVoid();
    if (trace) writeln(" = void");

    if (trace) write("testString(\"Test\")");
    string s = client.testString("Test");
    if (trace) writefln(" = \"%s\"", s);
    enforce(s == "Test");

    if (trace) write("testByte(1)");
    byte u8 = client.testByte(1);
    if (trace) writefln(" = %s", u8);
    enforce(u8 == 1);

    if (trace) write("testI32(-1)");
    int i32 = client.testI32(-1);
    if (trace) writefln(" = %s", i32);
    enforce(i32 == -1);

    if (trace) write("testI64(-34359738368)");
    long i64 = client.testI64(-34359738368L);
    if (trace) writefln(" = %s", i64);
    enforce(i64 == -34359738368L);

    if (trace) write("testDouble(-5.2098523)");
    double dub = client.testDouble(-5.2098523);
    if (trace) writefln(" = %s", dub);
    enforce(dub == -5.2098523);

    // TODO: add testBinary() call

    Xtruct out1;
    out1.string_thing = "Zero";
    out1.byte_thing = 1;
    out1.i32_thing = -3;
    out1.i64_thing = -5;
    if (trace) writef("testStruct(%s)", out1);
    auto in1 = client.testStruct(out1);
    if (trace) writefln(" = %s", in1);
    enforce(in1 == out1);

    if (trace) write("testNest({1, {\"Zero\", 1, -3, -5}), 5}");
    Xtruct2 out2;
    out2.byte_thing = 1;
    out2.struct_thing = out1;
    out2.i32_thing = 5;
    auto in2 = client.testNest(out2);
    in1 = in2.struct_thing;
    if (trace) writefln(" = {%s, {\"%s\", %s, %s, %s}, %s}", in2.byte_thing,
      in1.string_thing, in1.byte_thing, in1.i32_thing, in1.i64_thing,
      in2.i32_thing);
    enforce(in2 == out2);

    int[int] mapout;
    for (int i = 0; i < 5; ++i) {
      mapout[i] = i - 10;
    }
    if (trace) writef("testMap({%s})", mapout);
    auto mapin = client.testMap(mapout);
    if (trace) writefln(" = {%s}", mapin);
    enforce(mapin == mapout);

    auto setout = new HashSet!int;
    for (int i = -2; i < 3; ++i) {
      setout ~= i;
    }
    if (trace) writef("testSet(%s)", setout);
    auto setin = client.testSet(setout);
    if (trace) writefln(" = %s", setin);
    enforce(setin == setout);

    int[] listout;
    for (int i = -2; i < 3; ++i) {
      listout ~= i;
    }
    if (trace) writef("testList(%s)", listout);
    auto listin = client.testList(listout);
    if (trace) writefln(" = %s", listin);
    enforce(listin == listout);

    {
      if (trace) write("testEnum(ONE)");
      auto ret = client.testEnum(Numberz.ONE);
      if (trace) writefln(" = %s", ret);
      enforce(ret == Numberz.ONE);

      if (trace) write("testEnum(TWO)");
      ret = client.testEnum(Numberz.TWO);
      if (trace) writefln(" = %s", ret);
      enforce(ret == Numberz.TWO);

      if (trace) write("testEnum(THREE)");
      ret = client.testEnum(Numberz.THREE);
      if (trace) writefln(" = %s", ret);
      enforce(ret == Numberz.THREE);

      if (trace) write("testEnum(FIVE)");
      ret = client.testEnum(Numberz.FIVE);
      if (trace) writefln(" = %s", ret);
      enforce(ret == Numberz.FIVE);

      if (trace) write("testEnum(EIGHT)");
      ret = client.testEnum(Numberz.EIGHT);
      if (trace) writefln(" = %s", ret);
      enforce(ret == Numberz.EIGHT);
    }

    if (trace) write("testTypedef(309858235082523)");
    UserId uid = client.testTypedef(309858235082523L);
    if (trace) writefln(" = %s", uid);
    enforce(uid == 309858235082523L);

    if (trace) write("testMapMap(1)");
    auto mm = client.testMapMap(1);
    if (trace) writefln(" = {%s}", mm);
    // Simply doing == doesn't seem to work for nested AAs, so compare key by
    // key in both directions. `in` makes a missing key fail the check instead
    // of raising a RangeError.
    foreach (key, value; mm) {
      auto expected = key in testMapMapReturn;
      enforce(expected !is null && *expected == value);
    }
    foreach (key, value; testMapMapReturn) {
      auto actual = key in mm;
      enforce(actual !is null && *actual == value);
    }

    Insanity insane;
    insane.userMap[Numberz.FIVE] = 5000;
    Xtruct truck;
    truck.string_thing = "Truck";
    truck.byte_thing = 8;
    truck.i32_thing = 8;
    truck.i64_thing = 8;
    insane.xtructs ~= truck;
    if (trace) write("testInsanity()");
    auto whoa = client.testInsanity(insane);
    if (trace) writefln(" = %s", whoa);

    // Commented out for now: this is cumbersome to write without opEquals
    // getting called on AA comparison.
    // enforce(whoa == testInsanityReturn);

    {
      try {
        if (trace) write("client.testException(\"Xception\") =>");
        client.testException("Xception");
        if (trace) writeln("  void\nFAILURE");
        throw new Exception("testException failed.");
      } catch (Xception e) {
        if (trace) writefln("  {%s, \"%s\"}", e.errorCode, e.message);
      }

      try {
        if (trace) write("client.testException(\"TException\") =>");
        // Must request the TException path; asking for "Xception" here would
        // pass trivially because Xception is itself caught as a TException.
        client.testException("TException");
        if (trace) writeln("  void\nFAILURE");
        throw new Exception("testException failed.");
      } catch (TException e) {
        if (trace) writefln("  {%s}", e.msg);
      }

      try {
        if (trace) write("client.testException(\"success\") =>");
        client.testException("success");
        if (trace) writeln("  void");
      } catch (Exception e) {
        if (trace) writeln("  exception\nFAILURE");
        throw new Exception("testException failed.");
      }
    }

    {
      try {
        if (trace) write("client.testMultiException(\"Xception\", \"test 1\") =>");
        auto result = client.testMultiException("Xception", "test 1");
        if (trace) writeln("  result\nFAILURE");
        throw new Exception("testMultiException failed.");
      } catch (Xception e) {
        if (trace) writefln("  {%s, \"%s\"}", e.errorCode, e.message);
      }

      try {
        if (trace) write("client.testMultiException(\"Xception2\", \"test 2\") =>");
        auto result = client.testMultiException("Xception2", "test 2");
        if (trace) writeln("  result\nFAILURE");
        throw new Exception("testMultiException failed.");
      } catch (Xception2 e) {
        if (trace) writefln("  {%s, {\"%s\"}}",
          e.errorCode, e.struct_thing.string_thing);
      }

      try {
        if (trace) writef("client.testMultiException(\"success\", \"test 3\") =>");
        auto result = client.testMultiException("success", "test 3");
        if (trace) writefln("  {{\"%s\"}}", result.string_thing);
      } catch (Exception e) {
        if (trace) writeln("  exception\nFAILURE");
        throw new Exception("testMultiException failed.");
      }
    }

    // Do not run oneway test when doing multiple iterations, as it blocks the
    // server for three seconds.
    if (numTests == 1) {
      if (trace) writef("client.testOneway(3) =>");
      auto onewayWatch = StopWatch(AutoStart.yes);
      client.testOneway(3);
      onewayWatch.stop();
      if (onewayWatch.peek.total!"msecs" > 200) {
        if (trace) {
          writefln("  FAILURE - took %s ms", onewayWatch.peek.total!"usecs" / 1000.0);
        }
        throw new Exception("testOneway failed.");
      } else {
        if (trace) {
          writefln("  success - took %s ms", onewayWatch.peek.total!"usecs"  / 1000.0);
        }
      }

      // Redo a simple test after the oneway to make sure we aren't "off by
      // one", which would be the case if the server treated oneway methods
      // like normal ones. The result must be checked for this to detect it.
      if (trace) write("re-test testI32(-1)");
      i32 = client.testI32(-1);
      if (trace) writefln(" = %s", i32);
      enforce(i32 == -1);
    }

    // Time metering.
    sw.stop();

    immutable tot = sw.peek.total!"usecs";
    if (trace) writefln("Total time: %s us\n", tot);

    time_tot += tot;
    if (time_min == 0 || tot < time_min) {
      time_min = tot;
    }
    if (tot > time_max) {
      time_max = tot;
    }

    sw.reset();
  }

  writeln("All tests done.");

  if (numTests > 1) {
    auto time_avg = time_tot / numTests;
    writefln("Min time: %s us", time_min);
    writefln("Max time: %s us", time_max);
    writefln("Avg time: %s us", time_avg);
  }
}
397