#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

"""Tests for Thrift's TSSLSocket / TSSLServerSocket transports.

Each connection test runs a ServerAcceptor thread that accepts one client,
reads 5 bytes from it (the client's b"hello") and replies with b"there".

NOTE: the thrift names used by the tests (TSSLSocket, TSSLServerSocket,
_match_has_ipaddress, TTransportException) are imported only in the
``__main__`` block at the bottom of this file. Run the file directly; a test
runner that merely imports the module leaves those names undefined.
"""

import inspect
import logging
import os
import platform
import ssl
import sys
import tempfile
import threading
import unittest
import warnings
from contextlib import contextmanager

import _import_local_thrift  # noqa

SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
SERVER_PEM = os.path.join(ROOT_DIR, 'test', 'keys', 'server.pem')
SERVER_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'server.crt')
SERVER_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'server.key')
CLIENT_CERT_NO_IP = os.path.join(ROOT_DIR, 'test', 'keys', 'client.crt')
CLIENT_KEY_NO_IP = os.path.join(ROOT_DIR, 'test', 'keys', 'client.key')
CLIENT_CERT = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.crt')
CLIENT_KEY = os.path.join(ROOT_DIR, 'test', 'keys', 'client_v3.key')
CLIENT_CA = os.path.join(ROOT_DIR, 'test', 'keys', 'CA.pem')

TEST_CIPHERS = 'DES-CBC3-SHA:ECDHE-RSA-AES128-GCM-SHA256'


class ServerAcceptor(threading.Thread):
    """Background thread that serves exactly one test connection.

    The thread does the following, in order:

    1. Calls ``server.listen()``.
    2. Publishes the bound port.
    3. Accepts a single client.
    4. Reads 5 bytes from that client and answers with b"there".

    Each stage sets a ``threading.Event`` so that the test thread can wait
    for it.

    :param server: server transport to listen on (TSSLServerSocket here).
    :param expect_failure: when True, errors while accepting or talking to
        the client are logged but not re-raised, because the test expects the
        connection to fail.
    """

    def __init__(self, server, expect_failure=False):
        super(ServerAcceptor, self).__init__()
        self.daemon = True
        self._server = server
        self._listening = threading.Event()
        # Exception raised by server.listen(), surfaced by await_listening().
        self._listen_error = None
        self._port = None
        self._port_bound = threading.Event()
        self._client = None
        self._client_accepted = threading.Event()
        self._expect_failure = expect_failure
        # Name the thread after the test that created it, so server-side
        # errors in the log can be traced back to a test. The immediate
        # callers are helper frames (the _connectable_client generator, then
        # contextlib's __enter__), so search outward for the first test*
        # function and fall back to the frame two levels up if none is found.
        # Context 0: only function names are needed, not source lines.
        stack = inspect.stack(0)
        try:
            self.name = stack[2][3]
            for frame_info in stack[1:]:
                if frame_info[3].startswith('test'):
                    self.name = frame_info[3]
                    break
        finally:
            # stack[0] is this very frame; drop it to avoid a reference cycle.
            del stack

    def run(self):
        try:
            self._server.listen()
        except Exception as ex:
            # Hand the failure to the test thread and release every waiter;
            # otherwise await_listening()/port/client would block forever.
            self._listen_error = ex
            self._port_bound.set()
            self._client_accepted.set()
            return
        finally:
            self._listening.set()

        try:
            address = self._server.handle.getsockname()
            # AF_INET addresses are 2-tuples (host, port) and AF_INET6 are
            # 4-tuples (host, port, ...), but in each case port is in the
            # second slot. AF_UNIX addresses are path strings with no port.
            if isinstance(address, tuple) and len(address) > 1:
                self._port = address[1]
        finally:
            self._port_bound.set()

        try:
            self._client = self._server.accept()
            if self._client:
                self._client.read(5)  # hello
                self._client.write(b"there")
        except Exception:
            logging.exception('error on server side (%s):', self.name)
            if not self._expect_failure:
                raise
        finally:
            self._client_accepted.set()

    def await_listening(self):
        """Block until listen() has been attempted; re-raise its error, if any."""
        self._listening.wait()
        if self._listen_error is not None:
            raise self._listen_error

    @property
    def port(self):
        """Bound port; blocks until known. None when the address has no port."""
        self._port_bound.wait()
        return self._port

    @property
    def client(self):
        """Accepted client transport (or None); blocks until accept finished."""
        self._client_accepted.wait()
        return self._client

    def close(self):
        """Close the accepted client, if any, and the server transport."""
        if self._client:
            self._client.close()
        self._server.close()


# Python 2.6 compat
class AssertRaises(object):
    """Minimal context-manager stand-in for ``assertRaises`` (Python 2.6).

    Suppresses an exception that is a subclass of ``expected``. It raises
    ``Exception('fail')`` if no exception occurred or if an exception of a
    different type occurred.
    """

    def __init__(self, expected):
        self._expected = expected

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        if not exc_type or not issubclass(exc_type, self._expected):
            raise Exception('fail')
        return True


class TSSLSocketTest(unittest.TestCase):
    def _server_socket(self, **kwargs):
        """Build a TSSLServerSocket on port 0 (OS-assigned port)."""
        return TSSLServerSocket(port=0, **kwargs)

    @contextmanager
    def _connectable_client(self, server, expect_failure=False, path=None, **client_kwargs):
        """Start a ServerAcceptor for ``server`` and yield ``(acceptor, client)``.

        The client is an unopened TSSLSocket aimed at the acceptor's port, or
        at the unix socket ``path`` when one is given. On exit, both the
        client and the acceptor are closed, including when the test body
        raised.
        """
        acc = ServerAcceptor(server, expect_failure)
        client = None
        try:
            acc.start()
            acc.await_listening()

            host, port = ('localhost', acc.port) if path is None else (None, None)
            client = TSSLSocket(host, port, unix_socket=path, **client_kwargs)
            yield acc, client
        finally:
            # A failed handshake/validation can leave the client's socket
            # open; close() is a no-op when it was never opened or is closed.
            if client is not None:
                client.close()
            acc.close()

    def _assert_connection_failure(self, server, path=None, **client_args):
        """Assert that a hello/there exchange raises TTransportException."""
        # Server-side errors are expected here; keep them out of the log.
        logging.disable(logging.CRITICAL)
        try:
            with self._connectable_client(server, True, path=path, **client_args) as (acc, client):
                # We need to wait for a connection failure, but not too long.  20ms is a tunable
                # compromise between test speed and stability
                client.setTimeout(20)
                with self._assert_raises(TTransportException):
                    client.open()
                    client.write(b"hello")
                    client.read(5)  # b"there"
        finally:
            logging.disable(logging.NOTSET)

    def _assert_raises(self, exc):
        """Return an assertRaises context manager, with a Python 2.6 fallback."""
        if sys.hexversion >= 0x020700F0:
            return self.assertRaises(exc)
        else:
            return AssertRaises(exc)

    def _assert_connection_success(self, server, path=None, **client_args):
        """Assert that the client can open, send b"hello" and read b"there"."""
        with self._connectable_client(server, path=path, **client_args) as (acc, client):
            try:
                self.assertFalse(client.isOpen())
                client.open()
                self.assertTrue(client.isOpen())
                client.write(b"hello")
                self.assertEqual(client.read(5), b"there")
                self.assertTrue(acc.client is not None)
            finally:
                client.close()

    # deprecated feature
    def test_deprecation(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT)
            self.assertEqual(len(w), 1)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            # Deprecated signature
            # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
            TSSLSocket('localhost', 0, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS)
            self.assertEqual(len(w), 7)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            # Deprecated signature
            # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
            TSSLServerSocket(None, 0, SERVER_PEM, None, TEST_CIPHERS)
            self.assertEqual(len(w), 3)

    # deprecated feature
    def test_set_cert_reqs_by_validate(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            c1 = TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT)
            self.assertEqual(c1.cert_reqs, ssl.CERT_REQUIRED)

            c1 = TSSLSocket('localhost', 0, validate=False)
            self.assertEqual(c1.cert_reqs, ssl.CERT_NONE)

            self.assertEqual(len(w), 2)

    # deprecated feature
    def test_set_validate_by_cert_reqs(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            c1 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_NONE)
            self.assertFalse(c1.validate)

            c2 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
            self.assertTrue(c2.validate)

            c3 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT)
            self.assertTrue(c3.validate)

            self.assertEqual(len(w), 3)

    def test_unix_domain_socket(self):
        if platform.system() == 'Windows':
            print('skipping test_unix_domain_socket')
            return
        # Reserve a unique path, then remove the file so the server can bind it.
        fd, path = tempfile.mkstemp()
        os.close(fd)
        os.unlink(path)
        try:
            server = self._server_socket(unix_socket=path, keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_success(server, path=path, cert_reqs=ssl.CERT_NONE)
        finally:
            # The socket file only exists if the server got as far as binding;
            # don't let a failing unlink mask the original error.
            if os.path.exists(path):
                os.unlink(path)

    def test_server_cert(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        # server cert not in ca_certs
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT)

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE)

    def test_set_server_cert(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=CLIENT_CERT)
        with self._assert_raises(Exception):
            server.certfile = 'foo'
        with self._assert_raises(Exception):
            server.certfile = None
        server.certfile = SERVER_CERT
        self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)

    def test_client_cert(self):
        if not _match_has_ipaddress:
            print('skipping test_client_cert')
            return
        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=SERVER_CERT, keyfile=SERVER_KEY)

        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT_NO_IP, keyfile=CLIENT_KEY_NO_IP)

        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)

        server = self._server_socket(
            cert_reqs=ssl.CERT_OPTIONAL, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)

    def test_ciphers(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        self._assert_connection_success(server, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS)

        if not TSSLSocket._has_ciphers:
            # unittest.skip is not available for Python 2.6
            print('skipping test_ciphers')
            return
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')

    def test_ssl2_and_ssl3_disabled(self):
        if not hasattr(ssl, 'PROTOCOL_SSLv3'):
            print('PROTOCOL_SSLv3 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)

            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT)

        if not hasattr(ssl, 'PROTOCOL_SSLv2'):
            print('PROTOCOL_SSLv2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)

            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT)

    def test_newer_tls(self):
        if not TSSLSocket._has_ssl_context:
            # unittest.skip is not available for Python 2.6
            print('skipping test_newer_tls')
            return
        if not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1'):
            print('PROTOCOL_TLSv1_1 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)

    def test_ssl_context(self):
        if not TSSLSocket._has_ssl_context:
            # unittest.skip is not available for Python 2.6
            print('skipping test_ssl_context')
            return
        server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        server_context.load_cert_chain(SERVER_CERT, SERVER_KEY)
        server_context.load_verify_locations(CLIENT_CA)
        server_context.verify_mode = ssl.CERT_REQUIRED
        server = self._server_socket(ssl_context=server_context)

        client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY)
        client_context.load_verify_locations(SERVER_CERT)
        client_context.verify_mode = ssl.CERT_REQUIRED

        self._assert_connection_success(server, ssl_context=client_context)


if __name__ == '__main__':
    logging.basicConfig(level=logging.WARN)
    from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket, _match_has_ipaddress
    from thrift.transport.TTransport import TTransportException

    unittest.main()