1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10#   http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
import inspect
import logging
import os
import platform
import ssl
import sys
import tempfile
import threading
import unittest
import warnings
from contextlib import contextmanager

# NOTE: keep the thrift imports after _import_local_thrift, which presumably
# makes the in-tree thrift package importable — confirm before reordering.
import _import_local_thrift  # noqa
from thrift.transport.TSSLSocket import TSSLSocket, TSSLServerSocket, _match_has_ipaddress
from thrift.transport.TTransport import TTransportException
33
# Directory containing this test script, with symlinks resolved.
SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
# Repository root: three directory levels above SCRIPT_DIR.
ROOT_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, os.pardir, os.pardir, os.pardir))

# TLS fixtures shared by the test suite live under <root>/test/keys.
_KEYS_DIR = os.path.join(ROOT_DIR, 'test', 'keys')
SERVER_PEM = os.path.join(_KEYS_DIR, 'server.pem')
SERVER_CERT = os.path.join(_KEYS_DIR, 'server.crt')
SERVER_KEY = os.path.join(_KEYS_DIR, 'server.key')
# Client certificate/key pair used to exercise IP-address checks
# (see test_client_cert).
CLIENT_CERT_NO_IP = os.path.join(_KEYS_DIR, 'client.crt')
CLIENT_KEY_NO_IP = os.path.join(_KEYS_DIR, 'client.key')
CLIENT_CERT = os.path.join(_KEYS_DIR, 'client_v3.crt')
CLIENT_KEY = os.path.join(_KEYS_DIR, 'client_v3.key')
CLIENT_CA = os.path.join(_KEYS_DIR, 'CA.pem')

# OpenSSL cipher list string passed to both client and server in test_ciphers.
TEST_CIPHERS = 'DES-CBC3-SHA:ECDHE-RSA-AES128-GCM-SHA256'
46
47
class ServerAcceptor(threading.Thread):
    """Daemon thread that serves exactly one client of a Thrift server socket.

    ``run`` listens on the server, records the bound port, accepts one client,
    reads 5 bytes (the client's ``b"hello"``) and replies ``b"there"``. The test
    thread synchronises through ``await_listening()``, ``port`` and ``client``.
    """

    def __init__(self, server, expect_failure=False):
        """
        :param server: server socket to serve; it must provide ``listen()``,
            ``accept()``, ``close()`` and a ``handle`` with ``getsockname()``.
        :param expect_failure: when true, an exception raised while accepting
            or talking to the client is logged and swallowed, not re-raised.
        """
        super(ServerAcceptor, self).__init__()
        self.daemon = True
        self._server = server
        self._listening = threading.Event()
        self._port = None
        self._port_bound = threading.Event()
        self._client = None
        self._client_accepted = threading.Event()
        self._expect_failure = expect_failure
        # Name the thread after a frame two levels up the call stack so that
        # server-side log lines carry a hint of where the acceptor came from.
        # NOTE(review): when constructed from a @contextmanager generator this
        # frame may be contextlib's __enter__ rather than the test — confirm.
        frame = inspect.stack(3)[2]
        self.name = frame[3]
        del frame  # do not keep frame objects alive longer than needed

    def run(self):
        try:
            self._server.listen()
            self._listening.set()
            self._port = self._bound_port()
            self._port_bound.set()
            self._serve_one_client()
        finally:
            # On the normal path these events are already set. If listen() or
            # getsockname() raised, setting them here keeps await_listening(),
            # ``port`` and ``client`` from blocking the test thread forever.
            self._listening.set()
            self._port_bound.set()
            self._client_accepted.set()

    def _bound_port(self):
        """Return the port the listening socket is bound to, or None.

        AF_INET addresses are 2-tuples (host, port) and AF_INET6 are 4-tuples
        (host, port, ...); in each case the port is in the second slot. A Unix
        domain socket reports its path (a string), which carries no port.
        """
        address = self._server.handle.getsockname()
        if isinstance(address, tuple) and len(address) > 1:
            return address[1]
        return None

    def _serve_one_client(self):
        """Accept a single client and perform the hello/there exchange."""
        try:
            self._client = self._server.accept()
            if self._client:
                self._client.read(5)  # hello
                self._client.write(b"there")
        except Exception:
            logging.exception('error on server side (%s):', self.name)
            if not self._expect_failure:
                raise

    def await_listening(self):
        """Block until the server has called listen() (or failed trying)."""
        self._listening.wait()

    @property
    def port(self):
        """Bound port; None for Unix domain sockets or if binding failed."""
        self._port_bound.wait()
        return self._port

    @property
    def client(self):
        """Accepted client, or None once accepting has failed or finished empty."""
        self._client_accepted.wait()
        return self._client

    def close(self):
        """Close the accepted client (if any) and the server socket."""
        if self._client:
            self._client.close()
        self._server.close()
105
106
# Python 2.6 compat: minimal stand-in for the context-manager form of
# unittest.TestCase.assertRaises, which only exists from Python 2.7 onwards.
class AssertRaises(object):
    """Context manager asserting that its body raises ``expected``.

    Mirrors ``unittest.TestCase.assertRaises``: a matching exception is
    swallowed, a missing exception raises ``AssertionError``, and an exception
    of any other type propagates unchanged so its traceback is not lost.
    """

    def __init__(self, expected):
        # Exception class (or tuple of classes) accepted by issubclass().
        self._expected = expected

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is None:
            name = getattr(self._expected, '__name__', repr(self._expected))
            raise AssertionError('%s not raised' % name)
        # Returning False re-raises an unexpected exception as-is instead of
        # replacing it with an uninformative one.
        return issubclass(exc_type, self._expected)
119
120
class TSSLSocketTest(unittest.TestCase):
    """End-to-end tests for TSSLSocket / TSSLServerSocket over local sockets.

    The helpers start a ServerAcceptor thread for a server socket, connect a
    TSSLSocket client to it, and check whether a hello/there exchange succeeds
    or fails with TTransportException.
    """

    def _server_socket(self, **kwargs):
        """Create a TSSLServerSocket on port 0 so the OS assigns a free port."""
        return TSSLServerSocket(port=0, **kwargs)

    @contextmanager
    def _connectable_client(self, server, expect_failure=False, path=None, **client_kwargs):
        """Start an acceptor for ``server`` and yield ``(acceptor, client)``.

        The client targets ``localhost`` on the acceptor's port, or the Unix
        socket ``path`` when one is given. The acceptor is closed on exit.
        """
        acc = ServerAcceptor(server, expect_failure)
        try:
            acc.start()
            acc.await_listening()

            host, port = ('localhost', acc.port) if path is None else (None, None)
            client = TSSLSocket(host, port, unix_socket=path, **client_kwargs)
            yield acc, client
        finally:
            acc.close()

    def _assert_connection_failure(self, server, path=None, **client_args):
        """Assert that connecting and exchanging data raises TTransportException."""
        # Server-side errors are expected here; keep them out of the test output.
        logging.disable(logging.CRITICAL)
        try:
            with self._connectable_client(server, True, path=path, **client_args) as (acc, client):
                try:
                    # We need to wait for a connection failure, but not too long.  20ms is a tunable
                    # compromise between test speed and stability
                    client.setTimeout(20)
                    with self._assert_raises(TTransportException):
                        client.open()
                        client.write(b"hello")
                        client.read(5)  # b"there"
                finally:
                    # Release the client socket, as _assert_connection_success does.
                    client.close()
        finally:
            logging.disable(logging.NOTSET)

    def _assert_raises(self, exc):
        """Return a context manager asserting that its body raises ``exc``."""
        if sys.hexversion >= 0x020700F0:
            return self.assertRaises(exc)
        else:
            return AssertRaises(exc)

    def _assert_connection_success(self, server, path=None, **client_args):
        """Assert that the client can open, send b"hello" and read b"there"."""
        with self._connectable_client(server, path=path, **client_args) as (acc, client):
            try:
                self.assertFalse(client.isOpen())
                client.open()
                self.assertTrue(client.isOpen())
                client.write(b"hello")
                self.assertEqual(client.read(5), b"there")
                self.assertTrue(acc.client is not None)
            finally:
                client.close()

    # deprecated feature
    def test_deprecation(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT)
            self.assertEqual(len(w), 1)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            # Deprecated signature
            # def __init__(self, host='localhost', port=9090, validate=True, ca_certs=None, keyfile=None, certfile=None, unix_socket=None, ciphers=None):
            TSSLSocket('localhost', 0, True, SERVER_CERT, CLIENT_KEY, CLIENT_CERT, None, TEST_CIPHERS)
            self.assertEqual(len(w), 7)

        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            # Deprecated signature
            # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
            TSSLServerSocket(None, 0, SERVER_PEM, None, TEST_CIPHERS)
            self.assertEqual(len(w), 3)

    # deprecated feature
    def test_set_cert_reqs_by_validate(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            c1 = TSSLSocket('localhost', 0, validate=True, ca_certs=SERVER_CERT)
            self.assertEqual(c1.cert_reqs, ssl.CERT_REQUIRED)

            c1 = TSSLSocket('localhost', 0, validate=False)
            self.assertEqual(c1.cert_reqs, ssl.CERT_NONE)

            self.assertEqual(len(w), 2)

    # deprecated feature
    def test_set_validate_by_cert_reqs(self):
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings('always', category=DeprecationWarning, module=self.__module__)
            c1 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_NONE)
            self.assertFalse(c1.validate)

            c2 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)
            self.assertTrue(c2.validate)

            c3 = TSSLSocket('localhost', 0, cert_reqs=ssl.CERT_OPTIONAL, ca_certs=SERVER_CERT)
            self.assertTrue(c3.validate)

            self.assertEqual(len(w), 3)

    def test_unix_domain_socket(self):
        if platform.system() == 'Windows':
            print('skipping test_unix_domain_socket')
            return
        # Reserve a unique file name, then remove the file so the server can
        # create its socket at that path.
        fd, path = tempfile.mkstemp()
        os.close(fd)
        os.unlink(path)
        try:
            server = self._server_socket(unix_socket=path, keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_success(server, path=path, cert_reqs=ssl.CERT_NONE)
        finally:
            # The socket file is absent if the server failed before creating
            # it; an unconditional unlink would then mask the real error.
            if os.path.exists(path):
                os.unlink(path)

    def test_server_cert(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        # server cert not in ca_certs
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=CLIENT_CERT)

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE)

    def test_set_server_cert(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=CLIENT_CERT)
        with self._assert_raises(Exception):
            server.certfile = 'foo'
        with self._assert_raises(Exception):
            server.certfile = None
        server.certfile = SERVER_CERT
        self._assert_connection_success(server, cert_reqs=ssl.CERT_REQUIRED, ca_certs=SERVER_CERT)

    def test_client_cert(self):
        if not _match_has_ipaddress:
            print('skipping test_client_cert')
            return
        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CERT)
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=SERVER_CERT, keyfile=SERVER_KEY)

        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_failure(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT_NO_IP, keyfile=CLIENT_KEY_NO_IP)

        server = self._server_socket(
            cert_reqs=ssl.CERT_REQUIRED, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)

        server = self._server_socket(
            cert_reqs=ssl.CERT_OPTIONAL, keyfile=SERVER_KEY,
            certfile=SERVER_CERT, ca_certs=CLIENT_CA)
        self._assert_connection_success(server, cert_reqs=ssl.CERT_NONE, certfile=CLIENT_CERT, keyfile=CLIENT_KEY)

    def test_ciphers(self):
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        self._assert_connection_success(server, ca_certs=SERVER_CERT, ciphers=TEST_CIPHERS)

        if not TSSLSocket._has_ciphers:
            # unittest.skip is not available for Python 2.6
            print('skipping test_ciphers')
            return
        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
        self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')

        server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ciphers=TEST_CIPHERS)
        self._assert_connection_failure(server, ca_certs=SERVER_CERT, ciphers='NULL')

    def test_ssl2_and_ssl3_disabled(self):
        if not hasattr(ssl, 'PROTOCOL_SSLv3'):
            print('PROTOCOL_SSLv3 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)

            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv3)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT)

        if not hasattr(ssl, 'PROTOCOL_SSLv2'):
            print('PROTOCOL_SSLv2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)

            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_SSLv2)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT)

    def test_newer_tls(self):
        if not TSSLSocket._has_ssl_context:
            # unittest.skip is not available for Python 2.6
            print('skipping test_newer_tls')
            return
        if not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1'):
            print('PROTOCOL_TLSv1_1 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)
            self._assert_connection_success(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)

        if not hasattr(ssl, 'PROTOCOL_TLSv1_1') or not hasattr(ssl, 'PROTOCOL_TLSv1_2'):
            print('PROTOCOL_TLSv1_1 and/or PROTOCOL_TLSv1_2 is not available')
        else:
            server = self._server_socket(keyfile=SERVER_KEY, certfile=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_2)
            self._assert_connection_failure(server, ca_certs=SERVER_CERT, ssl_version=ssl.PROTOCOL_TLSv1_1)

    def test_ssl_context(self):
        if not TSSLSocket._has_ssl_context:
            # unittest.skip is not available for Python 2.6
            print('skipping test_ssl_context')
            return
        server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        server_context.load_cert_chain(SERVER_CERT, SERVER_KEY)
        server_context.load_verify_locations(CLIENT_CA)
        server_context.verify_mode = ssl.CERT_REQUIRED
        server = self._server_socket(ssl_context=server_context)

        client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        client_context.load_cert_chain(CLIENT_CERT, CLIENT_KEY)
        client_context.load_verify_locations(SERVER_CERT)
        client_context.verify_mode = ssl.CERT_REQUIRED

        self._assert_connection_success(server, ssl_context=client_context)
348
349
if __name__ == '__main__':
    # Show WARNING-and-above log records, e.g. ServerAcceptor's server-side
    # errors. The thrift names the tests use are imported at module level, so
    # the suite also works under unittest/pytest discovery, not only here.
    logging.basicConfig(level=logging.WARNING)

    unittest.main()
356