1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 package org.apache.thrift.transport;
21 
22 import static org.junit.jupiter.api.Assertions.assertEquals;
23 import static org.junit.jupiter.api.Assertions.assertNotNull;
24 import static org.junit.jupiter.api.Assertions.assertNull;
25 import static org.junit.jupiter.api.Assertions.assertThrows;
26 import static org.junit.jupiter.api.Assertions.assertTrue;
27 import static org.junit.jupiter.api.Assertions.fail;
28 
29 import java.io.IOException;
30 import java.nio.charset.StandardCharsets;
31 import java.util.HashMap;
32 import java.util.Map;
33 import javax.security.auth.callback.Callback;
34 import javax.security.auth.callback.CallbackHandler;
35 import javax.security.auth.callback.NameCallback;
36 import javax.security.auth.callback.PasswordCallback;
37 import javax.security.auth.callback.UnsupportedCallbackException;
38 import javax.security.sasl.AuthorizeCallback;
39 import javax.security.sasl.RealmCallback;
40 import javax.security.sasl.Sasl;
41 import javax.security.sasl.SaslClient;
42 import javax.security.sasl.SaslClientFactory;
43 import javax.security.sasl.SaslException;
44 import javax.security.sasl.SaslServer;
45 import javax.security.sasl.SaslServerFactory;
46 import org.apache.thrift.TConfiguration;
47 import org.apache.thrift.TProcessor;
48 import org.apache.thrift.protocol.TProtocolFactory;
49 import org.apache.thrift.server.ServerTestBase;
50 import org.apache.thrift.server.TServer;
51 import org.apache.thrift.server.TServer.Args;
52 import org.apache.thrift.server.TSimpleServer;
53 import org.junit.jupiter.api.Test;
54 import org.slf4j.Logger;
55 import org.slf4j.LoggerFactory;
56 
57 public class TestTSaslTransports {
58 
59   private static final Logger LOGGER = LoggerFactory.getLogger(TestTSaslTransports.class);
60 
61   public static final String HOST = "localhost";
62   public static final String SERVICE = "thrift-test";
63   public static final String PRINCIPAL = "thrift-test-principal";
64   public static final String PASSWORD = "super secret password";
65   public static final String REALM = "thrift-test-realm";
66 
67   public static final String UNWRAPPED_MECHANISM = "CRAM-MD5";
68   public static final Map<String, String> UNWRAPPED_PROPS = null;
69 
70   public static final String WRAPPED_MECHANISM = "DIGEST-MD5";
71   public static final Map<String, String> WRAPPED_PROPS = new HashMap<String, String>();
72 
73   static {
WRAPPED_PROPS.put(Sasl.QOP, R)74     WRAPPED_PROPS.put(Sasl.QOP, "auth-int");
75     WRAPPED_PROPS.put("com.sun.security.sasl.digest.realm", REALM);
76   }
77 
78   private static final String testMessage1 =
79       "Hello, world! Also, four "
80           + "score and seven years ago our fathers brought forth on this "
81           + "continent a new nation, conceived in liberty, and dedicated to the "
82           + "proposition that all men are created equal.";
83 
84   private static final String testMessage2 =
85       "I have a dream that one day "
86           + "this nation will rise up and live out the true meaning of its creed: "
87           + "'We hold these truths to be self-evident, that all men are created equal.'";
88 
89   public static class TestSaslCallbackHandler implements CallbackHandler {
90     private final String password;
91 
TestSaslCallbackHandler(String password)92     public TestSaslCallbackHandler(String password) {
93       this.password = password;
94     }
95 
96     @Override
handle(Callback[] callbacks)97     public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
98       for (Callback c : callbacks) {
99         if (c instanceof NameCallback) {
100           ((NameCallback) c).setName(PRINCIPAL);
101         } else if (c instanceof PasswordCallback) {
102           ((PasswordCallback) c).setPassword(password.toCharArray());
103         } else if (c instanceof AuthorizeCallback) {
104           ((AuthorizeCallback) c).setAuthorized(true);
105         } else if (c instanceof RealmCallback) {
106           ((RealmCallback) c).setText(REALM);
107         } else {
108           throw new UnsupportedCallbackException(c);
109         }
110       }
111     }
112   }
113 
114   private static class ServerThread extends Thread {
115     final String mechanism;
116     final Map<String, String> props;
117     volatile Throwable thrown;
118 
ServerThread(String mechanism, Map<String, String> props)119     public ServerThread(String mechanism, Map<String, String> props) {
120       this.mechanism = mechanism;
121       this.props = props;
122     }
123 
run()124     public void run() {
125       try {
126         internalRun();
127       } catch (Throwable t) {
128         thrown = t;
129       }
130     }
131 
internalRun()132     private void internalRun() throws Exception {
133       try (TServerSocket serverSocket =
134           new TServerSocket(
135               new TServerSocket.ServerSocketTransportArgs().port(ServerTestBase.PORT))) {
136         acceptAndWrite(serverSocket);
137       }
138     }
139 
acceptAndWrite(TServerSocket serverSocket)140     private void acceptAndWrite(TServerSocket serverSocket) throws Exception {
141       TTransport serverTransport = serverSocket.accept();
142       TTransport saslServerTransport =
143           new TSaslServerTransport(
144               mechanism,
145               SERVICE,
146               HOST,
147               props,
148               new TestSaslCallbackHandler(PASSWORD),
149               serverTransport);
150 
151       saslServerTransport.open();
152 
153       byte[] inBuf = new byte[testMessage1.getBytes().length];
154       // Deliberately read less than the full buffer to ensure
155       // that TSaslTransport is correctly buffering reads. This
156       // will fail for the WRAPPED test, if it doesn't work.
157       saslServerTransport.readAll(inBuf, 0, 5);
158       saslServerTransport.readAll(inBuf, 5, 10);
159       saslServerTransport.readAll(inBuf, 15, inBuf.length - 15);
160       LOGGER.debug("server got: {}", new String(inBuf));
161       assertEquals(new String(inBuf), testMessage1);
162 
163       LOGGER.debug("server writing: {}", testMessage2);
164       saslServerTransport.write(testMessage2.getBytes());
165       saslServerTransport.flush();
166 
167       saslServerTransport.close();
168     }
169   }
170 
testSaslOpen(final String mechanism, final Map<String, String> props)171   private void testSaslOpen(final String mechanism, final Map<String, String> props)
172       throws Exception {
173     ServerThread serverThread = new ServerThread(mechanism, props);
174     serverThread.start();
175 
176     try {
177       Thread.sleep(1000);
178     } catch (InterruptedException e) {
179       // Ah well.
180     }
181 
182     try {
183       TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT);
184       TTransport saslClientTransport =
185           new TSaslClientTransport(
186               mechanism,
187               PRINCIPAL,
188               SERVICE,
189               HOST,
190               props,
191               new TestSaslCallbackHandler(PASSWORD),
192               clientSocket);
193       saslClientTransport.open();
194       LOGGER.debug("client writing: {}", testMessage1);
195       saslClientTransport.write(testMessage1.getBytes());
196       saslClientTransport.flush();
197 
198       byte[] inBuf = new byte[testMessage2.getBytes().length];
199       saslClientTransport.readAll(inBuf, 0, inBuf.length);
200       LOGGER.debug("client got: {}", new String(inBuf));
201       assertEquals(new String(inBuf), testMessage2);
202 
203       TTransportException expectedException = null;
204       try {
205         saslClientTransport.open();
206       } catch (TTransportException e) {
207         expectedException = e;
208       }
209       assertNotNull(expectedException);
210 
211       saslClientTransport.close();
212     } catch (Exception e) {
213       LOGGER.warn("Exception caught", e);
214       throw e;
215     } finally {
216       serverThread.interrupt();
217       try {
218         serverThread.join();
219       } catch (InterruptedException e) {
220         // Ah well.
221       }
222       assertNull(serverThread.thrown);
223     }
224   }
225 
226   @Test
testUnwrappedOpen()227   public void testUnwrappedOpen() throws Exception {
228     testSaslOpen(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS);
229   }
230 
231   @Test
testWrappedOpen()232   public void testWrappedOpen() throws Exception {
233     testSaslOpen(WRAPPED_MECHANISM, WRAPPED_PROPS);
234   }
235 
236   @Test
testAnonymousOpen()237   public void testAnonymousOpen() throws Exception {
238     testSaslOpen("ANONYMOUS", null);
239   }
240 
241   /**
242    * Test that we get the proper exceptions thrown back the server when the client provides invalid
243    * password.
244    */
245   @Test
testBadPassword()246   public void testBadPassword() throws Exception {
247     ServerThread serverThread = new ServerThread(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS);
248     serverThread.start();
249 
250     try {
251       Thread.sleep(1000);
252     } catch (InterruptedException e) {
253       // Ah well.
254     }
255 
256     TTransportException tte =
257         assertThrows(
258             TTransportException.class,
259             () -> {
260               TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT);
261               TTransport saslClientTransport =
262                   new TSaslClientTransport(
263                       UNWRAPPED_MECHANISM,
264                       PRINCIPAL,
265                       SERVICE,
266                       HOST,
267                       UNWRAPPED_PROPS,
268                       new TestSaslCallbackHandler("NOT THE PASSWORD"),
269                       clientSocket);
270               saslClientTransport.open();
271             },
272             "Was able to open transport with bad password");
273     LOGGER.error("Exception for bad password", tte);
274     assertNotNull(tte.getMessage());
275     assertTrue(tte.getMessage().contains("Invalid response"));
276     serverThread.interrupt();
277     serverThread.join();
278     assertNotNull(serverThread.thrown);
279     assertTrue(serverThread.thrown.getMessage().contains("Invalid response"));
280   }
281 
282   @Test
testWithServer()283   public void testWithServer() throws Exception {
284     new TestTSaslTransportsWithServer().testIt();
285   }
286 
287   public static class TestTSaslTransportsWithServer extends ServerTestBase {
288 
289     private Thread serverThread;
290     private TServer server;
291 
292     @Override
getClientTransport(TTransport underlyingTransport)293     public TTransport getClientTransport(TTransport underlyingTransport) throws Exception {
294       return new TSaslClientTransport(
295           WRAPPED_MECHANISM,
296           PRINCIPAL,
297           SERVICE,
298           HOST,
299           WRAPPED_PROPS,
300           new TestSaslCallbackHandler(PASSWORD),
301           underlyingTransport);
302     }
303 
304     @Override
startServer( final TProcessor processor, final TProtocolFactory protoFactory, final TTransportFactory factory)305     public void startServer(
306         final TProcessor processor,
307         final TProtocolFactory protoFactory,
308         final TTransportFactory factory)
309         throws Exception {
310       serverThread =
311           new Thread() {
312             public void run() {
313               try {
314                 // Transport
315                 TServerSocket socket =
316                     new TServerSocket(new TServerSocket.ServerSocketTransportArgs().port(PORT));
317 
318                 TTransportFactory factory =
319                     new TSaslServerTransport.Factory(
320                         WRAPPED_MECHANISM,
321                         SERVICE,
322                         HOST,
323                         WRAPPED_PROPS,
324                         new TestSaslCallbackHandler(PASSWORD));
325                 server =
326                     new TSimpleServer(
327                         new Args(socket)
328                             .processor(processor)
329                             .transportFactory(factory)
330                             .protocolFactory(protoFactory));
331 
332                 // Run it
333                 LOGGER.debug("Starting the server on port {}", PORT);
334                 server.serve();
335               } catch (Exception e) {
336                 e.printStackTrace();
337                 fail(e);
338               }
339             }
340           };
341       serverThread.start();
342       Thread.sleep(1000);
343     }
344 
345     @Override
stopServer()346     public void stopServer() throws Exception {
347       server.stop();
348       try {
349         serverThread.join();
350       } catch (InterruptedException e) {
351         LOGGER.debug("interrupted during sleep", e);
352       }
353     }
354   }
355 
356   /** Implementation of SASL ANONYMOUS, used for testing client-side initial responses. */
357   private static class AnonymousClient implements SaslClient {
358     private final String username;
359     private boolean hasProvidedInitialResponse;
360 
AnonymousClient(String username)361     public AnonymousClient(String username) {
362       this.username = username;
363     }
364 
365     @Override
getMechanismName()366     public String getMechanismName() {
367       return "ANONYMOUS";
368     }
369 
370     @Override
hasInitialResponse()371     public boolean hasInitialResponse() {
372       return true;
373     }
374 
375     @Override
evaluateChallenge(byte[] challenge)376     public byte[] evaluateChallenge(byte[] challenge) throws SaslException {
377       if (hasProvidedInitialResponse) {
378         throw new SaslException("Already complete!");
379       }
380 
381       hasProvidedInitialResponse = true;
382       return username.getBytes(StandardCharsets.UTF_8);
383     }
384 
385     @Override
isComplete()386     public boolean isComplete() {
387       return hasProvidedInitialResponse;
388     }
389 
390     @Override
unwrap(byte[] incoming, int offset, int len)391     public byte[] unwrap(byte[] incoming, int offset, int len) {
392       throw new UnsupportedOperationException();
393     }
394 
395     @Override
wrap(byte[] outgoing, int offset, int len)396     public byte[] wrap(byte[] outgoing, int offset, int len) {
397       throw new UnsupportedOperationException();
398     }
399 
400     @Override
getNegotiatedProperty(String propName)401     public Object getNegotiatedProperty(String propName) {
402       return null;
403     }
404 
405     @Override
dispose()406     public void dispose() {}
407   }
408 
409   private static class AnonymousServer implements SaslServer {
410     private String user;
411 
412     @Override
getMechanismName()413     public String getMechanismName() {
414       return "ANONYMOUS";
415     }
416 
417     @Override
evaluateResponse(byte[] response)418     public byte[] evaluateResponse(byte[] response) throws SaslException {
419       this.user = new String(response, StandardCharsets.UTF_8);
420       return null;
421     }
422 
423     @Override
isComplete()424     public boolean isComplete() {
425       return user != null;
426     }
427 
428     @Override
getAuthorizationID()429     public String getAuthorizationID() {
430       return user;
431     }
432 
433     @Override
unwrap(byte[] incoming, int offset, int len)434     public byte[] unwrap(byte[] incoming, int offset, int len) {
435       throw new UnsupportedOperationException();
436     }
437 
438     @Override
wrap(byte[] outgoing, int offset, int len)439     public byte[] wrap(byte[] outgoing, int offset, int len) {
440       throw new UnsupportedOperationException();
441     }
442 
443     @Override
getNegotiatedProperty(String propName)444     public Object getNegotiatedProperty(String propName) {
445       return null;
446     }
447 
448     @Override
dispose()449     public void dispose() {}
450   }
451 
452   public static class SaslAnonymousFactory implements SaslClientFactory, SaslServerFactory {
453 
454     @Override
createSaslClient( String[] mechanisms, String authorizationId, String protocol, String serverName, Map<String, ?> props, CallbackHandler cbh)455     public SaslClient createSaslClient(
456         String[] mechanisms,
457         String authorizationId,
458         String protocol,
459         String serverName,
460         Map<String, ?> props,
461         CallbackHandler cbh) {
462       for (String mech : mechanisms) {
463         if ("ANONYMOUS".equals(mech)) {
464           return new AnonymousClient(authorizationId);
465         }
466       }
467       return null;
468     }
469 
470     @Override
createSaslServer( String mechanism, String protocol, String serverName, Map<String, ?> props, CallbackHandler cbh)471     public SaslServer createSaslServer(
472         String mechanism,
473         String protocol,
474         String serverName,
475         Map<String, ?> props,
476         CallbackHandler cbh) {
477       if ("ANONYMOUS".equals(mechanism)) {
478         return new AnonymousServer();
479       }
480       return null;
481     }
482 
483     @Override
getMechanismNames(Map<String, ?> props)484     public String[] getMechanismNames(Map<String, ?> props) {
485       return new String[] {"ANONYMOUS"};
486     }
487   }
488 
489   static {
java.security.Security.addProvider(new SaslAnonymousProvider())490     java.security.Security.addProvider(new SaslAnonymousProvider());
491   }
492 
493   public static class SaslAnonymousProvider extends java.security.Provider {
SaslAnonymousProvider()494     public SaslAnonymousProvider() {
495       super("ThriftSaslAnonymous", "1.0", "Thrift Anonymous SASL provider");
496       put("SaslClientFactory.ANONYMOUS", SaslAnonymousFactory.class.getName());
497       put("SaslServerFactory.ANONYMOUS", SaslAnonymousFactory.class.getName());
498     }
499   }
500 
501   private static class MockTTransport extends TTransport {
502 
503     byte[] badHeader = null;
504     private final TMemoryInputTransport readBuffer;
505 
MockTTransport(int mode)506     public MockTTransport(int mode) throws TTransportException {
507       readBuffer = new TMemoryInputTransport();
508       if (mode == 1) {
509         // Invalid status byte
510         badHeader = new byte[] {(byte) 0xFF, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x05};
511       } else if (mode == 2) {
512         // Valid status byte, negative payload length
513         badHeader = new byte[] {(byte) 0x01, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF};
514       } else if (mode == 3) {
515         // Valid status byte, excessively large, bogus payload length
516         badHeader = new byte[] {(byte) 0x01, (byte) 0x64, (byte) 0x00, (byte) 0x00, (byte) 0x00};
517       }
518       readBuffer.reset(badHeader);
519     }
520 
521     @Override
isOpen()522     public boolean isOpen() {
523       return true;
524     }
525 
526     @Override
open()527     public void open() throws TTransportException {}
528 
529     @Override
close()530     public void close() {}
531 
532     @Override
read(byte[] buf, int off, int len)533     public int read(byte[] buf, int off, int len) throws TTransportException {
534       return readBuffer.read(buf, off, len);
535     }
536 
537     @Override
write(byte[] buf, int off, int len)538     public void write(byte[] buf, int off, int len) throws TTransportException {}
539 
540     @Override
getConfiguration()541     public TConfiguration getConfiguration() {
542       return readBuffer.getConfiguration();
543     }
544 
545     @Override
updateKnownMessageSize(long size)546     public void updateKnownMessageSize(long size) throws TTransportException {
547       readBuffer.updateKnownMessageSize(size);
548     }
549 
550     @Override
checkReadBytesAvailable(long numBytes)551     public void checkReadBytesAvailable(long numBytes) throws TTransportException {
552       readBuffer.checkReadBytesAvailable(numBytes);
553     }
554   }
555 
556   @Test
testBadHeader()557   public void testBadHeader() {
558     TSaslTransport saslTransport;
559     try {
560       saslTransport = new TSaslServerTransport(new MockTTransport(1));
561       saslTransport.receiveSaslMessage();
562       fail("Should have gotten an error due to incorrect status byte value.");
563     } catch (TTransportException e) {
564       assertEquals(e.getMessage(), "Invalid status -1");
565     }
566     try {
567       saslTransport = new TSaslServerTransport(new MockTTransport(2));
568       saslTransport.receiveSaslMessage();
569       fail("Should have gotten an error due to negative payload length.");
570     } catch (TTransportException e) {
571       assertEquals(e.getMessage(), "Invalid payload header length: -1");
572     }
573     try {
574       saslTransport = new TSaslServerTransport(new MockTTransport(3));
575       saslTransport.receiveSaslMessage();
576       fail("Should have gotten an error due to bogus (large) payload length.");
577     } catch (TTransportException e) {
578       assertEquals(e.getMessage(), "Invalid payload header length: 1677721600");
579     }
580   }
581 }
582