1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 package org.apache.thrift.transport; 21 22 import static org.junit.jupiter.api.Assertions.assertEquals; 23 import static org.junit.jupiter.api.Assertions.assertNotNull; 24 import static org.junit.jupiter.api.Assertions.assertNull; 25 import static org.junit.jupiter.api.Assertions.assertThrows; 26 import static org.junit.jupiter.api.Assertions.assertTrue; 27 import static org.junit.jupiter.api.Assertions.fail; 28 29 import java.io.IOException; 30 import java.nio.charset.StandardCharsets; 31 import java.util.HashMap; 32 import java.util.Map; 33 import javax.security.auth.callback.Callback; 34 import javax.security.auth.callback.CallbackHandler; 35 import javax.security.auth.callback.NameCallback; 36 import javax.security.auth.callback.PasswordCallback; 37 import javax.security.auth.callback.UnsupportedCallbackException; 38 import javax.security.sasl.AuthorizeCallback; 39 import javax.security.sasl.RealmCallback; 40 import javax.security.sasl.Sasl; 41 import javax.security.sasl.SaslClient; 42 import javax.security.sasl.SaslClientFactory; 43 import javax.security.sasl.SaslException; 44 import javax.security.sasl.SaslServer; 45 import javax.security.sasl.SaslServerFactory; 46 import org.apache.thrift.TConfiguration; 47 import org.apache.thrift.TProcessor; 48 import org.apache.thrift.protocol.TProtocolFactory; 49 import org.apache.thrift.server.ServerTestBase; 50 import org.apache.thrift.server.TServer; 51 import org.apache.thrift.server.TServer.Args; 52 import org.apache.thrift.server.TSimpleServer; 53 import org.junit.jupiter.api.Test; 54 import org.slf4j.Logger; 55 import org.slf4j.LoggerFactory; 56 57 public class TestTSaslTransports { 58 59 private static final Logger LOGGER = LoggerFactory.getLogger(TestTSaslTransports.class); 60 61 public static final String HOST = "localhost"; 62 public static final String SERVICE = "thrift-test"; 63 public static final String PRINCIPAL = "thrift-test-principal"; 64 public static final String PASSWORD = "super secret password"; 65 public static final String REALM = "thrift-test-realm"; 66 67 public static final String UNWRAPPED_MECHANISM = "CRAM-MD5"; 68 public static final Map<String, String> UNWRAPPED_PROPS = null; 69 70 public static final String WRAPPED_MECHANISM = "DIGEST-MD5"; 71 public static final Map<String, String> WRAPPED_PROPS = new HashMap<String, String>(); 72 73 static { WRAPPED_PROPS.put(Sasl.QOP, R)74 WRAPPED_PROPS.put(Sasl.QOP, "auth-int"); 75 WRAPPED_PROPS.put("com.sun.security.sasl.digest.realm", REALM); 76 } 77 78 private static final String testMessage1 = 79 "Hello, world! Also, four " 80 + "score and seven years ago our fathers brought forth on this " 81 + "continent a new nation, conceived in liberty, and dedicated to the " 82 + "proposition that all men are created equal."; 83 84 private static final String testMessage2 = 85 "I have a dream that one day " 86 + "this nation will rise up and live out the true meaning of its creed: " 87 + "'We hold these truths to be self-evident, that all men are created equal.'"; 88 89 public static class TestSaslCallbackHandler implements CallbackHandler { 90 private final String password; 91 TestSaslCallbackHandler(String password)92 public TestSaslCallbackHandler(String password) { 93 this.password = password; 94 } 95 96 @Override handle(Callback[] callbacks)97 public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { 98 for (Callback c : callbacks) { 99 if (c instanceof NameCallback) { 100 ((NameCallback) c).setName(PRINCIPAL); 101 } else if (c instanceof PasswordCallback) { 102 ((PasswordCallback) c).setPassword(password.toCharArray()); 103 } else if (c instanceof AuthorizeCallback) { 104 ((AuthorizeCallback) c).setAuthorized(true); 105 } else if (c instanceof RealmCallback) { 106 ((RealmCallback) c).setText(REALM); 107 } else { 108 throw new UnsupportedCallbackException(c); 109 } 110 } 111 } 112 } 113 114 private static class ServerThread extends Thread { 115 final String mechanism; 116 final Map<String, String> props; 117 volatile Throwable thrown; 118 ServerThread(String mechanism, Map<String, String> props)119 public ServerThread(String mechanism, Map<String, String> props) { 120 this.mechanism = mechanism; 121 this.props = props; 122 } 123 run()124 public void run() { 125 try { 126 internalRun(); 127 } catch (Throwable t) { 128 thrown = t; 129 } 130 } 131 internalRun()132 private void internalRun() throws Exception { 133 try (TServerSocket serverSocket = 134 new TServerSocket( 135 new TServerSocket.ServerSocketTransportArgs().port(ServerTestBase.PORT))) { 136 acceptAndWrite(serverSocket); 137 } 138 } 139 acceptAndWrite(TServerSocket serverSocket)140 private void acceptAndWrite(TServerSocket serverSocket) throws Exception { 141 TTransport serverTransport = serverSocket.accept(); 142 TTransport saslServerTransport = 143 new TSaslServerTransport( 144 mechanism, 145 SERVICE, 146 HOST, 147 props, 148 new TestSaslCallbackHandler(PASSWORD), 149 serverTransport); 150 151 saslServerTransport.open(); 152 153 byte[] inBuf = new byte[testMessage1.getBytes().length]; 154 // Deliberately read less than the full buffer to ensure 155 // that TSaslTransport is correctly buffering reads. This 156 // will fail for the WRAPPED test, if it doesn't work. 157 saslServerTransport.readAll(inBuf, 0, 5); 158 saslServerTransport.readAll(inBuf, 5, 10); 159 saslServerTransport.readAll(inBuf, 15, inBuf.length - 15); 160 LOGGER.debug("server got: {}", new String(inBuf)); 161 assertEquals(new String(inBuf), testMessage1); 162 163 LOGGER.debug("server writing: {}", testMessage2); 164 saslServerTransport.write(testMessage2.getBytes()); 165 saslServerTransport.flush(); 166 167 saslServerTransport.close(); 168 } 169 } 170 testSaslOpen(final String mechanism, final Map<String, String> props)171 private void testSaslOpen(final String mechanism, final Map<String, String> props) 172 throws Exception { 173 ServerThread serverThread = new ServerThread(mechanism, props); 174 serverThread.start(); 175 176 try { 177 Thread.sleep(1000); 178 } catch (InterruptedException e) { 179 // Ah well. 180 } 181 182 try { 183 TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT); 184 TTransport saslClientTransport = 185 new TSaslClientTransport( 186 mechanism, 187 PRINCIPAL, 188 SERVICE, 189 HOST, 190 props, 191 new TestSaslCallbackHandler(PASSWORD), 192 clientSocket); 193 saslClientTransport.open(); 194 LOGGER.debug("client writing: {}", testMessage1); 195 saslClientTransport.write(testMessage1.getBytes()); 196 saslClientTransport.flush(); 197 198 byte[] inBuf = new byte[testMessage2.getBytes().length]; 199 saslClientTransport.readAll(inBuf, 0, inBuf.length); 200 LOGGER.debug("client got: {}", new String(inBuf)); 201 assertEquals(new String(inBuf), testMessage2); 202 203 TTransportException expectedException = null; 204 try { 205 saslClientTransport.open(); 206 } catch (TTransportException e) { 207 expectedException = e; 208 } 209 assertNotNull(expectedException); 210 211 saslClientTransport.close(); 212 } catch (Exception e) { 213 LOGGER.warn("Exception caught", e); 214 throw e; 215 } finally { 216 serverThread.interrupt(); 217 try { 218 serverThread.join(); 219 } catch (InterruptedException e) { 220 // Ah well. 221 } 222 assertNull(serverThread.thrown); 223 } 224 } 225 226 @Test testUnwrappedOpen()227 public void testUnwrappedOpen() throws Exception { 228 testSaslOpen(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS); 229 } 230 231 @Test testWrappedOpen()232 public void testWrappedOpen() throws Exception { 233 testSaslOpen(WRAPPED_MECHANISM, WRAPPED_PROPS); 234 } 235 236 @Test testAnonymousOpen()237 public void testAnonymousOpen() throws Exception { 238 testSaslOpen("ANONYMOUS", null); 239 } 240 241 /** 242 * Test that we get the proper exceptions thrown back the server when the client provides invalid 243 * password. 244 */ 245 @Test testBadPassword()246 public void testBadPassword() throws Exception { 247 ServerThread serverThread = new ServerThread(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS); 248 serverThread.start(); 249 250 try { 251 Thread.sleep(1000); 252 } catch (InterruptedException e) { 253 // Ah well. 254 } 255 256 TTransportException tte = 257 assertThrows( 258 TTransportException.class, 259 () -> { 260 TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT); 261 TTransport saslClientTransport = 262 new TSaslClientTransport( 263 UNWRAPPED_MECHANISM, 264 PRINCIPAL, 265 SERVICE, 266 HOST, 267 UNWRAPPED_PROPS, 268 new TestSaslCallbackHandler("NOT THE PASSWORD"), 269 clientSocket); 270 saslClientTransport.open(); 271 }, 272 "Was able to open transport with bad password"); 273 LOGGER.error("Exception for bad password", tte); 274 assertNotNull(tte.getMessage()); 275 assertTrue(tte.getMessage().contains("Invalid response")); 276 serverThread.interrupt(); 277 serverThread.join(); 278 assertNotNull(serverThread.thrown); 279 assertTrue(serverThread.thrown.getMessage().contains("Invalid response")); 280 } 281 282 @Test testWithServer()283 public void testWithServer() throws Exception { 284 new TestTSaslTransportsWithServer().testIt(); 285 } 286 287 public static class TestTSaslTransportsWithServer extends ServerTestBase { 288 289 private Thread serverThread; 290 private TServer server; 291 292 @Override getClientTransport(TTransport underlyingTransport)293 public TTransport getClientTransport(TTransport underlyingTransport) throws Exception { 294 return new TSaslClientTransport( 295 WRAPPED_MECHANISM, 296 PRINCIPAL, 297 SERVICE, 298 HOST, 299 WRAPPED_PROPS, 300 new TestSaslCallbackHandler(PASSWORD), 301 underlyingTransport); 302 } 303 304 @Override startServer( final TProcessor processor, final TProtocolFactory protoFactory, final TTransportFactory factory)305 public void startServer( 306 final TProcessor processor, 307 final TProtocolFactory protoFactory, 308 final TTransportFactory factory) 309 throws Exception { 310 serverThread = 311 new Thread() { 312 public void run() { 313 try { 314 // Transport 315 TServerSocket socket = 316 new TServerSocket(new TServerSocket.ServerSocketTransportArgs().port(PORT)); 317 318 TTransportFactory factory = 319 new TSaslServerTransport.Factory( 320 WRAPPED_MECHANISM, 321 SERVICE, 322 HOST, 323 WRAPPED_PROPS, 324 new TestSaslCallbackHandler(PASSWORD)); 325 server = 326 new TSimpleServer( 327 new Args(socket) 328 .processor(processor) 329 .transportFactory(factory) 330 .protocolFactory(protoFactory)); 331 332 // Run it 333 LOGGER.debug("Starting the server on port {}", PORT); 334 server.serve(); 335 } catch (Exception e) { 336 e.printStackTrace(); 337 fail(e); 338 } 339 } 340 }; 341 serverThread.start(); 342 Thread.sleep(1000); 343 } 344 345 @Override stopServer()346 public void stopServer() throws Exception { 347 server.stop(); 348 try { 349 serverThread.join(); 350 } catch (InterruptedException e) { 351 LOGGER.debug("interrupted during sleep", e); 352 } 353 } 354 } 355 356 /** Implementation of SASL ANONYMOUS, used for testing client-side initial responses. */ 357 private static class AnonymousClient implements SaslClient { 358 private final String username; 359 private boolean hasProvidedInitialResponse; 360 AnonymousClient(String username)361 public AnonymousClient(String username) { 362 this.username = username; 363 } 364 365 @Override getMechanismName()366 public String getMechanismName() { 367 return "ANONYMOUS"; 368 } 369 370 @Override hasInitialResponse()371 public boolean hasInitialResponse() { 372 return true; 373 } 374 375 @Override evaluateChallenge(byte[] challenge)376 public byte[] evaluateChallenge(byte[] challenge) throws SaslException { 377 if (hasProvidedInitialResponse) { 378 throw new SaslException("Already complete!"); 379 } 380 381 hasProvidedInitialResponse = true; 382 return username.getBytes(StandardCharsets.UTF_8); 383 } 384 385 @Override isComplete()386 public boolean isComplete() { 387 return hasProvidedInitialResponse; 388 } 389 390 @Override unwrap(byte[] incoming, int offset, int len)391 public byte[] unwrap(byte[] incoming, int offset, int len) { 392 throw new UnsupportedOperationException(); 393 } 394 395 @Override wrap(byte[] outgoing, int offset, int len)396 public byte[] wrap(byte[] outgoing, int offset, int len) { 397 throw new UnsupportedOperationException(); 398 } 399 400 @Override getNegotiatedProperty(String propName)401 public Object getNegotiatedProperty(String propName) { 402 return null; 403 } 404 405 @Override dispose()406 public void dispose() {} 407 } 408 409 private static class AnonymousServer implements SaslServer { 410 private String user; 411 412 @Override getMechanismName()413 public String getMechanismName() { 414 return "ANONYMOUS"; 415 } 416 417 @Override evaluateResponse(byte[] response)418 public byte[] evaluateResponse(byte[] response) throws SaslException { 419 this.user = new String(response, StandardCharsets.UTF_8); 420 return null; 421 } 422 423 @Override isComplete()424 public boolean isComplete() { 425 return user != null; 426 } 427 428 @Override getAuthorizationID()429 public String getAuthorizationID() { 430 return user; 431 } 432 433 @Override unwrap(byte[] incoming, int offset, int len)434 public byte[] unwrap(byte[] incoming, int offset, int len) { 435 throw new UnsupportedOperationException(); 436 } 437 438 @Override wrap(byte[] outgoing, int offset, int len)439 public byte[] wrap(byte[] outgoing, int offset, int len) { 440 throw new UnsupportedOperationException(); 441 } 442 443 @Override getNegotiatedProperty(String propName)444 public Object getNegotiatedProperty(String propName) { 445 return null; 446 } 447 448 @Override dispose()449 public void dispose() {} 450 } 451 452 public static class SaslAnonymousFactory implements SaslClientFactory, SaslServerFactory { 453 454 @Override createSaslClient( String[] mechanisms, String authorizationId, String protocol, String serverName, Map<String, ?> props, CallbackHandler cbh)455 public SaslClient createSaslClient( 456 String[] mechanisms, 457 String authorizationId, 458 String protocol, 459 String serverName, 460 Map<String, ?> props, 461 CallbackHandler cbh) { 462 for (String mech : mechanisms) { 463 if ("ANONYMOUS".equals(mech)) { 464 return new AnonymousClient(authorizationId); 465 } 466 } 467 return null; 468 } 469 470 @Override createSaslServer( String mechanism, String protocol, String serverName, Map<String, ?> props, CallbackHandler cbh)471 public SaslServer createSaslServer( 472 String mechanism, 473 String protocol, 474 String serverName, 475 Map<String, ?> props, 476 CallbackHandler cbh) { 477 if ("ANONYMOUS".equals(mechanism)) { 478 return new AnonymousServer(); 479 } 480 return null; 481 } 482 483 @Override getMechanismNames(Map<String, ?> props)484 public String[] getMechanismNames(Map<String, ?> props) { 485 return new String[] {"ANONYMOUS"}; 486 } 487 } 488 489 static { java.security.Security.addProvider(new SaslAnonymousProvider())490 java.security.Security.addProvider(new SaslAnonymousProvider()); 491 } 492 493 public static class SaslAnonymousProvider extends java.security.Provider { SaslAnonymousProvider()494 public SaslAnonymousProvider() { 495 super("ThriftSaslAnonymous", "1.0", "Thrift Anonymous SASL provider"); 496 put("SaslClientFactory.ANONYMOUS", SaslAnonymousFactory.class.getName()); 497 put("SaslServerFactory.ANONYMOUS", SaslAnonymousFactory.class.getName()); 498 } 499 } 500 501 private static class MockTTransport extends TTransport { 502 503 byte[] badHeader = null; 504 private final TMemoryInputTransport readBuffer; 505 MockTTransport(int mode)506 public MockTTransport(int mode) throws TTransportException { 507 readBuffer = new TMemoryInputTransport(); 508 if (mode == 1) { 509 // Invalid status byte 510 badHeader = new byte[] {(byte) 0xFF, (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x05}; 511 } else if (mode == 2) { 512 // Valid status byte, negative payload length 513 badHeader = new byte[] {(byte) 0x01, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF}; 514 } else if (mode == 3) { 515 // Valid status byte, excessively large, bogus payload length 516 badHeader = new byte[] {(byte) 0x01, (byte) 0x64, (byte) 0x00, (byte) 0x00, (byte) 0x00}; 517 } 518 readBuffer.reset(badHeader); 519 } 520 521 @Override isOpen()522 public boolean isOpen() { 523 return true; 524 } 525 526 @Override open()527 public void open() throws TTransportException {} 528 529 @Override close()530 public void close() {} 531 532 @Override read(byte[] buf, int off, int len)533 public int read(byte[] buf, int off, int len) throws TTransportException { 534 return readBuffer.read(buf, off, len); 535 } 536 537 @Override write(byte[] buf, int off, int len)538 public void write(byte[] buf, int off, int len) throws TTransportException {} 539 540 @Override getConfiguration()541 public TConfiguration getConfiguration() { 542 return readBuffer.getConfiguration(); 543 } 544 545 @Override updateKnownMessageSize(long size)546 public void updateKnownMessageSize(long size) throws TTransportException { 547 readBuffer.updateKnownMessageSize(size); 548 } 549 550 @Override checkReadBytesAvailable(long numBytes)551 public void checkReadBytesAvailable(long numBytes) throws TTransportException { 552 readBuffer.checkReadBytesAvailable(numBytes); 553 } 554 } 555 556 @Test testBadHeader()557 public void testBadHeader() { 558 TSaslTransport saslTransport; 559 try { 560 saslTransport = new TSaslServerTransport(new MockTTransport(1)); 561 saslTransport.receiveSaslMessage(); 562 fail("Should have gotten an error due to incorrect status byte value."); 563 } catch (TTransportException e) { 564 assertEquals(e.getMessage(), "Invalid status -1"); 565 } 566 try { 567 saslTransport = new TSaslServerTransport(new MockTTransport(2)); 568 saslTransport.receiveSaslMessage(); 569 fail("Should have gotten an error due to negative payload length."); 570 } catch (TTransportException e) { 571 assertEquals(e.getMessage(), "Invalid payload header length: -1"); 572 } 573 try { 574 saslTransport = new TSaslServerTransport(new MockTTransport(3)); 575 saslTransport.receiveSaslMessage(); 576 fail("Should have gotten an error due to bogus (large) payload length."); 577 } catch (TTransportException e) { 578 assertEquals(e.getMessage(), "Invalid payload header length: 1677721600"); 579 } 580 } 581 } 582