1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 package org.apache.thrift.protocol;
20 
21 import static org.junit.jupiter.api.Assertions.assertEquals;
22 import static org.junit.jupiter.api.Assertions.assertThrows;
23 
24 import java.nio.ByteBuffer;
25 import java.util.Arrays;
26 import java.util.HashMap;
27 import java.util.List;
28 import java.util.Map;
29 import java.util.Set;
30 import java.util.UUID;
31 import java.util.stream.Collectors;
32 import java.util.stream.Stream;
33 import org.apache.thrift.Fixtures;
34 import org.apache.thrift.TBase;
35 import org.apache.thrift.TConfiguration;
36 import org.apache.thrift.TDeserializer;
37 import org.apache.thrift.TException;
38 import org.apache.thrift.TSerializer;
39 import org.apache.thrift.server.ServerTestBase;
40 import org.apache.thrift.transport.TMemoryBuffer;
41 import org.apache.thrift.transport.TTransportException;
42 import org.junit.jupiter.api.Test;
43 import thrift.test.CompactProtoTestStruct;
44 import thrift.test.HolyMoley;
45 import thrift.test.Nesting;
46 import thrift.test.OneOfEach;
47 import thrift.test.Srv;
48 import thrift.test.ThriftTest;
49 
50 public abstract class ProtocolTestBase {
51 
52   /** Does it make sense to call methods like writeI32 directly on your protocol? */
canBeUsedNaked()53   protected abstract boolean canBeUsedNaked();
54 
55   /** The protocol factory for the protocol being tested. */
getFactory()56   protected abstract TProtocolFactory getFactory();
57 
58   @Test
testDouble()59   public void testDouble() throws Exception {
60     if (canBeUsedNaked()) {
61       TMemoryBuffer buf = new TMemoryBuffer(1000);
62       TProtocol proto = getFactory().getProtocol(buf);
63       proto.writeDouble(123.456);
64       assertEquals(123.456, proto.readDouble());
65     }
66 
67     internalTestStructField(
68         new StructFieldTestCase(TType.DOUBLE, (short) 15) {
69           @Override
70           public void readMethod(TProtocol proto) throws TException {
71             assertEquals(123.456, proto.readDouble());
72           }
73 
74           @Override
75           public void writeMethod(TProtocol proto) throws TException {
76             proto.writeDouble(123.456);
77           }
78         });
79   }
80 
81   @Test
testSerialization()82   public void testSerialization() throws Exception {
83     internalTestSerialization(OneOfEach.class, Fixtures.getOneOfEach());
84     internalTestSerialization(Nesting.class, Fixtures.getNesting());
85     internalTestSerialization(HolyMoley.class, Fixtures.getHolyMoley());
86     internalTestSerialization(CompactProtoTestStruct.class, Fixtures.getCompactProtoTestStruct());
87   }
88 
89   @Test
testBinary()90   public void testBinary() throws Exception {
91     for (byte[] b :
92         Arrays.asList(
93             new byte[0],
94             new byte[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
95             new byte[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14},
96             new byte[] {0x5D},
97             new byte[] {(byte) 0xD5, (byte) 0x5D},
98             new byte[] {(byte) 0xFF, (byte) 0xD5, (byte) 0x5D},
99             new byte[128])) {
100       if (canBeUsedNaked()) {
101         internalTestNakedBinary(b);
102       }
103       internalTestBinaryField(b);
104     }
105 
106     if (canBeUsedNaked()) {
107       byte[] data = {1, 2, 3, 4, 5, 6};
108 
109       TMemoryBuffer buf = new TMemoryBuffer(0);
110       TProtocol proto = getFactory().getProtocol(buf);
111       ByteBuffer bb = ByteBuffer.wrap(data);
112       bb.get();
113       proto.writeBinary(bb.slice());
114       assertEquals(ByteBuffer.wrap(data, 1, 5), proto.readBinary());
115     }
116   }
117 
118   @Test
testString()119   public void testString() throws Exception {
120     for (String s :
121         Arrays.asList("", "short", "borderlinetiny", "a bit longer than the smallest possible")) {
122       if (canBeUsedNaked()) {
123         internalTestNakedString(s);
124       }
125       internalTestStringField(s);
126     }
127   }
128 
129   @Test
testUuid()130   public void testUuid() throws Exception {
131     UUID uuid = UUID.fromString("00112233-4455-6677-8899-aabbccddeeff");
132     if (canBeUsedNaked()) {
133       internalTestNakedUuid(uuid);
134     }
135     internalTestUuidField(uuid);
136   }
137 
138   @Test
testLong()139   public void testLong() throws Exception {
140     if (canBeUsedNaked()) {
141       internalTestNakedI64(0);
142     }
143     internalTestI64Field(0);
144     for (int i = 0; i < 62; i++) {
145       if (canBeUsedNaked()) {
146         internalTestNakedI64(1L << i);
147         internalTestNakedI64(-(1L << i));
148       }
149       internalTestI64Field(1L << i);
150       internalTestI64Field(-(1L << i));
151     }
152   }
153 
154   @Test
testInt()155   public void testInt() throws Exception {
156     for (int i :
157         Arrays.asList(
158             0, 1, 7, 150, 15000, 31337, 0xffff, 0xffffff, -1, -7, -150, -15000, -0xffff,
159             -0xffffff)) {
160       if (canBeUsedNaked()) {
161         internalTestNakedI32(i);
162       }
163       internalTestI32Field(i);
164     }
165   }
166 
167   @Test
testShort()168   public void testShort() throws Exception {
169     for (int s : Arrays.asList(0, 1, 7, 150, 15000, 0x7fff, -1, -7, -150, -15000, -0x7fff)) {
170       if (canBeUsedNaked()) {
171         internalTestNakedI16((short) s);
172       }
173       internalTestI16Field((short) s);
174     }
175   }
176 
177   @Test
testByte()178   public void testByte() throws Exception {
179     if (canBeUsedNaked()) {
180       internalTestNakedByte();
181     }
182     for (int i = 0; i < 128; i++) {
183       internalTestByteField((byte) i);
184       internalTestByteField((byte) -i);
185     }
186   }
187 
internalTestNakedByte()188   private void internalTestNakedByte() throws Exception {
189     TMemoryBuffer buf = new TMemoryBuffer(1000);
190     TProtocol proto = getFactory().getProtocol(buf);
191     proto.writeByte((byte) 123);
192     assertEquals((byte) 123, proto.readByte());
193   }
194 
internalTestByteField(final byte b)195   private void internalTestByteField(final byte b) throws Exception {
196     internalTestStructField(
197         new StructFieldTestCase(TType.BYTE, (short) 15) {
198           public void writeMethod(TProtocol proto) throws TException {
199             proto.writeByte(b);
200           }
201 
202           public void readMethod(TProtocol proto) throws TException {
203             assertEquals(b, proto.readByte());
204           }
205         });
206   }
207 
internalTestNakedI16(short n)208   private void internalTestNakedI16(short n) throws Exception {
209     TMemoryBuffer buf = new TMemoryBuffer(0);
210     TProtocol proto = getFactory().getProtocol(buf);
211     proto.writeI16(n);
212     assertEquals(n, proto.readI16());
213   }
214 
internalTestI16Field(final short n)215   private void internalTestI16Field(final short n) throws Exception {
216     internalTestStructField(
217         new StructFieldTestCase(TType.I16, (short) 15) {
218           public void writeMethod(TProtocol proto) throws TException {
219             proto.writeI16(n);
220           }
221 
222           public void readMethod(TProtocol proto) throws TException {
223             assertEquals(n, proto.readI16());
224           }
225         });
226   }
227 
internalTestNakedUuid(UUID uuid)228   private void internalTestNakedUuid(UUID uuid) throws TException {
229     TMemoryBuffer buf = new TMemoryBuffer(0);
230     TProtocol protocol = getFactory().getProtocol(buf);
231     protocol.writeUuid(uuid);
232     assertEquals(uuid, protocol.readUuid());
233   }
234 
internalTestUuidField(UUID uuid)235   private void internalTestUuidField(UUID uuid) throws Exception {
236     internalTestStructField(
237         new StructFieldTestCase(TType.UUID, (short) 17) {
238           @Override
239           public void writeMethod(TProtocol proto) throws TException {
240             proto.writeUuid(uuid);
241           }
242 
243           @Override
244           public void readMethod(TProtocol proto) throws TException {
245             assertEquals(uuid, proto.readUuid());
246           }
247         });
248   }
249 
internalTestNakedI32(int n)250   private void internalTestNakedI32(int n) throws Exception {
251     TMemoryBuffer buf = new TMemoryBuffer(0);
252     TProtocol proto = getFactory().getProtocol(buf);
253     proto.writeI32(n);
254     assertEquals(n, proto.readI32());
255   }
256 
internalTestI32Field(final int n)257   private void internalTestI32Field(final int n) throws Exception {
258     internalTestStructField(
259         new StructFieldTestCase(TType.I32, (short) 15) {
260           public void writeMethod(TProtocol proto) throws TException {
261             proto.writeI32(n);
262           }
263 
264           public void readMethod(TProtocol proto) throws TException {
265             assertEquals(n, proto.readI32());
266           }
267         });
268   }
269 
internalTestNakedI64(long n)270   private void internalTestNakedI64(long n) throws Exception {
271     TMemoryBuffer buf = new TMemoryBuffer(0);
272     TProtocol proto = getFactory().getProtocol(buf);
273     proto.writeI64(n);
274     assertEquals(n, proto.readI64());
275   }
276 
internalTestI64Field(final long n)277   private void internalTestI64Field(final long n) throws Exception {
278     internalTestStructField(
279         new StructFieldTestCase(TType.I64, (short) 15) {
280           public void writeMethod(TProtocol proto) throws TException {
281             proto.writeI64(n);
282           }
283 
284           public void readMethod(TProtocol proto) throws TException {
285             assertEquals(n, proto.readI64());
286           }
287         });
288   }
289 
internalTestNakedString(String str)290   private void internalTestNakedString(String str) throws Exception {
291     TMemoryBuffer buf = new TMemoryBuffer(0);
292     TProtocol proto = getFactory().getProtocol(buf);
293     proto.writeString(str);
294     assertEquals(str, proto.readString());
295   }
296 
internalTestStringField(final String str)297   private void internalTestStringField(final String str) throws Exception {
298     internalTestStructField(
299         new StructFieldTestCase(TType.STRING, (short) 15) {
300           public void writeMethod(TProtocol proto) throws TException {
301             proto.writeString(str);
302           }
303 
304           public void readMethod(TProtocol proto) throws TException {
305             assertEquals(str, proto.readString());
306           }
307         });
308   }
309 
internalTestNakedBinary(byte[] data)310   private void internalTestNakedBinary(byte[] data) throws Exception {
311     TMemoryBuffer buf = new TMemoryBuffer(0);
312     TProtocol proto = getFactory().getProtocol(buf);
313     proto.writeBinary(ByteBuffer.wrap(data));
314     assertEquals(ByteBuffer.wrap(data), proto.readBinary());
315   }
316 
internalTestBinaryField(final byte[] data)317   private void internalTestBinaryField(final byte[] data) throws Exception {
318     internalTestStructField(
319         new StructFieldTestCase(TType.STRING, (short) 15) {
320           public void writeMethod(TProtocol proto) throws TException {
321             proto.writeBinary(ByteBuffer.wrap(data));
322           }
323 
324           public void readMethod(TProtocol proto) throws TException {
325             assertEquals(ByteBuffer.wrap(data), proto.readBinary());
326           }
327         });
328   }
329 
internalTestSerialization(Class<T> klass, T expected)330   private <T extends TBase> void internalTestSerialization(Class<T> klass, T expected)
331       throws Exception {
332     TMemoryBuffer buf = new TMemoryBuffer(0);
333     TBinaryProtocol binproto = new TBinaryProtocol(buf);
334 
335     expected.write(binproto);
336 
337     buf = new TMemoryBuffer(0);
338     TProtocol proto = getFactory().getProtocol(buf);
339 
340     expected.write(proto);
341     System.out.println("Size in " + proto.getClass().getSimpleName() + ": " + buf.length());
342 
343     T actual = klass.getDeclaredConstructor().newInstance();
344     actual.read(proto);
345     assertEquals(expected, actual);
346   }
347 
348   @Test
testMessage()349   public void testMessage() throws Exception {
350     List<TMessage> msgs =
351         Arrays.asList(
352             new TMessage[] {
353               new TMessage("short message name", TMessageType.CALL, 0),
354               new TMessage("1", TMessageType.REPLY, 12345),
355               new TMessage(
356                   "loooooooooooooooooooooooooooooooooong", TMessageType.EXCEPTION, 1 << 16),
357               new TMessage("Janky", TMessageType.CALL, 0),
358             });
359 
360     for (TMessage msg : msgs) {
361       TMemoryBuffer buf = new TMemoryBuffer(0);
362       TProtocol proto = getFactory().getProtocol(buf);
363       TMessage output = null;
364 
365       proto.writeMessageBegin(msg);
366       proto.writeMessageEnd();
367 
368       output = proto.readMessageBegin();
369 
370       assertEquals(msg, output);
371     }
372   }
373 
374   @Test
testServerRequest()375   public void testServerRequest() throws Exception {
376     Srv.Iface handler =
377         new Srv.Iface() {
378           public int Janky(int i32arg) throws TException {
379             return i32arg * 2;
380           }
381 
382           public int primitiveMethod() throws TException {
383             return 0;
384           }
385 
386           public CompactProtoTestStruct structMethod() throws TException {
387             return null;
388           }
389 
390           public void voidMethod() throws TException {}
391 
392           public void methodWithDefaultArgs(int something) throws TException {}
393 
394           @Override
395           public void onewayMethod() throws TException {}
396 
397           @Override
398           public boolean declaredExceptionMethod(boolean shouldThrow) throws TException {
399             return shouldThrow;
400           }
401         };
402 
403     Srv.Processor testProcessor = new Srv.Processor(handler);
404 
405     TMemoryBuffer clientOutTrans = new TMemoryBuffer(0);
406     TProtocol clientOutProto = getFactory().getProtocol(clientOutTrans);
407     TMemoryBuffer clientInTrans = new TMemoryBuffer(0);
408     TProtocol clientInProto = getFactory().getProtocol(clientInTrans);
409 
410     Srv.Client testClient = new Srv.Client(clientInProto, clientOutProto);
411 
412     testClient.send_Janky(1);
413     // System.out.println(clientOutTrans.inspect());
414     testProcessor.process(clientOutProto, clientInProto);
415     // System.out.println(clientInTrans.inspect());
416     assertEquals(2, testClient.recv_Janky());
417   }
418 
419   @Test
testTDeserializer()420   public void testTDeserializer() throws TException {
421     TSerializer ser = new TSerializer(getFactory());
422     byte[] bytes = ser.serialize(Fixtures.getCompactProtoTestStruct());
423 
424     TDeserializer deser = new TDeserializer(getFactory());
425     CompactProtoTestStruct cpts = new CompactProtoTestStruct();
426     deser.deserialize(cpts, bytes);
427 
428     assertEquals(Fixtures.getCompactProtoTestStruct(), cpts);
429   }
430 
431   //
432   // Helper methods
433   //
434 
internalTestStructField(StructFieldTestCase testCase)435   private void internalTestStructField(StructFieldTestCase testCase) throws Exception {
436     TMemoryBuffer buf = new TMemoryBuffer(0);
437     TProtocol proto = getFactory().getProtocol(buf);
438 
439     TField field = new TField("test_field", testCase.type_, testCase.id_);
440     proto.writeStructBegin(new TStruct("test_struct"));
441     proto.writeFieldBegin(field);
442     testCase.writeMethod(proto);
443     proto.writeFieldEnd();
444     proto.writeStructEnd();
445 
446     proto.readStructBegin();
447     TField readField = proto.readFieldBegin();
448     assertEquals(testCase.id_, readField.id);
449     assertEquals(testCase.type_, readField.type);
450     testCase.readMethod(proto);
451     proto.readStructEnd();
452   }
453 
454   private abstract static class StructFieldTestCase {
455     byte type_;
456     short id_;
457 
StructFieldTestCase(byte type, short id)458     public StructFieldTestCase(byte type, short id) {
459       type_ = type;
460       id_ = id;
461     }
462 
writeMethod(TProtocol proto)463     public abstract void writeMethod(TProtocol proto) throws TException;
464 
readMethod(TProtocol proto)465     public abstract void readMethod(TProtocol proto) throws TException;
466   }
467 
468   private static final int NUM_TRIALS = 5;
469   private static final int NUM_REPS = 10000;
470 
benchmark()471   protected void benchmark() throws Exception {
472     for (int trial = 0; trial < NUM_TRIALS; trial++) {
473       TSerializer ser = new TSerializer(getFactory());
474       byte[] serialized = null;
475       long serStart = System.currentTimeMillis();
476       for (int rep = 0; rep < NUM_REPS; rep++) {
477         serialized = ser.serialize(Fixtures.getHolyMoley());
478       }
479       long serEnd = System.currentTimeMillis();
480       long serElapsed = serEnd - serStart;
481       System.out.println(
482           "Ser:\t"
483               + serElapsed
484               + "ms\t"
485               + ((double) serElapsed / NUM_REPS)
486               + "ms per serialization");
487 
488       HolyMoley cpts = new HolyMoley();
489       TDeserializer deser = new TDeserializer(getFactory());
490       long deserStart = System.currentTimeMillis();
491       for (int rep = 0; rep < NUM_REPS; rep++) {
492         deser.deserialize(cpts, serialized);
493       }
494       long deserEnd = System.currentTimeMillis();
495       long deserElapsed = deserEnd - deserStart;
496       System.out.println(
497           "Des:\t"
498               + deserElapsed
499               + "ms\t"
500               + ((double) deserElapsed / NUM_REPS)
501               + "ms per deserialization");
502     }
503   }
504 
505   private final ServerTestBase.TestHandler testHandler =
506       new ServerTestBase.TestHandler() {
507         @Override
508         public String testString(String thing) {
509           thing = thing + " Apache Thrift Java " + thing;
510           return thing;
511         }
512 
513         @Override
514         public List<Integer> testList(List<Integer> thing) {
515           thing.addAll(thing);
516           thing.addAll(thing);
517           return thing;
518         }
519 
520         @Override
521         public Set<Integer> testSet(Set<Integer> thing) {
522           thing.addAll(thing.stream().map(x -> x + 100).collect(Collectors.toSet()));
523           return thing;
524         }
525 
526         @Override
527         public Map<String, String> testStringMap(Map<String, String> thing) {
528           thing.put("a", "123");
529           thing.put(" x y ", " with spaces ");
530           thing.put("same", "same");
531           thing.put("0", "numeric key");
532           thing.put("1", "");
533           thing.put("ok", "2355555");
534           thing.put("end", "0");
535           return thing;
536         }
537       };
538 
initConfig(int maxSize)539   private TProtocol initConfig(int maxSize) throws TException {
540     TConfiguration config = TConfiguration.custom().setMaxMessageSize(maxSize).build();
541     TMemoryBuffer bufferTrans = new TMemoryBuffer(config, 0);
542     return getFactory().getProtocol(bufferTrans);
543   }
544 
545   @Test
testReadCheckMaxMessageRequestForString()546   public void testReadCheckMaxMessageRequestForString() throws TException {
547     TProtocol clientOutProto = initConfig(15);
548     TProtocol clientInProto = initConfig(15);
549     ThriftTest.Client testClient = new ThriftTest.Client(clientInProto, clientOutProto);
550     ThriftTest.Processor testProcessor = new ThriftTest.Processor(testHandler);
551     try {
552       testClient.send_testString("test");
553       testProcessor.process(clientOutProto, clientInProto);
554       String result = testClient.recv_testString();
555       System.out.println("----result: " + result);
556     } catch (TException e) {
557       assertEquals("MaxMessageSize reached", e.getMessage());
558     }
559   }
560 
561   @Test
testReadCheckMaxMessageRequestForList()562   public void testReadCheckMaxMessageRequestForList() throws TException {
563     TProtocol clientOutProto = initConfig(15);
564     TProtocol clientInProto = initConfig(15);
565     ThriftTest.Client testClient = new ThriftTest.Client(clientInProto, clientOutProto);
566     ThriftTest.Processor testProcessor = new ThriftTest.Processor(testHandler);
567     TTransportException e =
568         assertThrows(
569             TTransportException.class,
570             () -> {
571               testClient.send_testList(Arrays.asList(1, 23242346, 888888, 90));
572               testProcessor.process(clientOutProto, clientInProto);
573               testClient.recv_testList();
574             },
575             "Limitations not achieved as expected");
576     assertEquals("MaxMessageSize reached", e.getMessage());
577   }
578 
579   @Test
testReadCheckMaxMessageRequestForMap()580   public void testReadCheckMaxMessageRequestForMap() throws TException {
581     TProtocol clientOutProto = initConfig(13);
582     TProtocol clientInProto = initConfig(13);
583     ThriftTest.Client testClient = new ThriftTest.Client(clientInProto, clientOutProto);
584     ThriftTest.Processor testProcessor = new ThriftTest.Processor(testHandler);
585     Map<String, String> thing = new HashMap<>();
586     thing.put("key", "Thrift");
587 
588     TTransportException e =
589         assertThrows(
590             TTransportException.class,
591             () -> {
592               testClient.send_testStringMap(thing);
593               testProcessor.process(clientOutProto, clientInProto);
594               testClient.recv_testStringMap();
595             },
596             "Limitations not achieved as expected");
597 
598     assertEquals("MaxMessageSize reached", e.getMessage());
599   }
600 
601   @Test
testReadCheckMaxMessageRequestForSet()602   public void testReadCheckMaxMessageRequestForSet() throws TException {
603     TProtocol clientOutProto = initConfig(10);
604     TProtocol clientInProto = initConfig(10);
605     ThriftTest.Client testClient = new ThriftTest.Client(clientInProto, clientOutProto);
606     ThriftTest.Processor testProcessor = new ThriftTest.Processor(testHandler);
607     TTransportException e =
608         assertThrows(
609             TTransportException.class,
610             () -> {
611               testClient.send_testSet(
612                   Stream.of(234, 0, 987087, 45, 88888888, 9).collect(Collectors.toSet()));
613               testProcessor.process(clientOutProto, clientInProto);
614               testClient.recv_testSet();
615             },
616             "Limitations not achieved as expected");
617     assertEquals("MaxMessageSize reached", e.getMessage());
618   }
619 }
620