1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef _THRIFT_TRANSPORT_TSSLSOCKET_H_
21 #define _THRIFT_TRANSPORT_TSSLSOCKET_H_ 1
22 
23 // Put this first to avoid WIN32 build failure
24 #include <thrift/transport/TSocket.h>
25 
26 #include <openssl/ssl.h>
27 #include <string>
28 #include <thrift/concurrency/Mutex.h>
29 
30 namespace apache {
31 namespace thrift {
32 namespace transport {
33 
34 class AccessManager;
35 class SSLContext;
36 
/**
 * Protocol selector for SSL/TLS connections.
 *
 * The numeric values are fixed explicitly. The value 1 is left unused on
 * purpose: it used to denote SSLv2, which is HORRIBLY INSECURE and is no
 * longer offered.
 */
enum SSLProtocol {
  /** Supports SSLv2 and SSLv3 handshake but only negotiates at TLSv1_0 or later. */
  SSLTLS = 0,
  /** Supports SSLv3 only - also horribly insecure! */
  SSLv3 = 2,
  /** Supports TLSv1_0 or later. */
  TLSv1_0 = 3,
  /** Supports TLSv1_1 or later. */
  TLSv1_1 = 4,
  /** Supports TLSv1_2 or later. */
  TLSv1_2 = 5,
  /** Alias for the newest protocol listed in this enum. */
  LATEST = TLSv1_2
};
46 
// Result codes returned by TSSLSocket::waitForEvent():
//   TSSL_EINTR - EINTR happened on the underlying socket.
//   TSSL_DATA  - data is available on the socket.
#define TSSL_EINTR 0
#define TSSL_DATA 1
49 
/**
 * Initialize the OpenSSL library.
 *
 * This function, or some other equivalent function that initializes OpenSSL,
 * must be called before TSSLSocket is used. If you set TSSLSocketFactory to
 * use manual OpenSSL initialization, you should call this function yourself
 * or otherwise ensure that OpenSSL is initialized.
 */
void initializeOpenSSL();
/**
 * Clean up the OpenSSL library.
 *
 * This function should be called to clean up OpenSSL once all use of OpenSSL
 * functionality is finished. If you set TSSLSocketFactory to use manual
 * OpenSSL initialization, you should call this function yourself or ensure
 * that whatever initialized OpenSSL also cleans it up.
 */
void cleanupOpenSSL();
66 
67 /**
68  * OpenSSL implementation for SSL socket interface.
69  */
70 class TSSLSocket : public TSocket {
71 public:
72   ~TSSLSocket() override;
73   /**
74    * TTransport interface.
75    */
76   bool isOpen() const override;
77   bool peek() override;
78   void open() override;
79   void close() override;
80   bool hasPendingDataToRead() override;
81   uint32_t read(uint8_t* buf, uint32_t len) override;
82   void write(const uint8_t* buf, uint32_t len) override;
83   uint32_t write_partial(const uint8_t* buf, uint32_t len) override;
84   void flush() override;
85   /**
86   * Set whether to use client or server side SSL handshake protocol.
87   *
88   * @param flag  Use server side handshake protocol if true.
89   */
server(bool flag)90   void server(bool flag) { server_ = flag; }
91   /**
92    * Determine whether the SSL socket is server or client mode.
93    */
server()94   bool server() const { return server_; }
95   /**
96    * Set AccessManager.
97    *
98    * @param manager  Instance of AccessManager
99    */
access(std::shared_ptr<AccessManager> manager)100   virtual void access(std::shared_ptr<AccessManager> manager) { access_ = manager; }
101   /**
102    * Set eventSafe flag if libevent is used.
103    */
setLibeventSafe()104   void setLibeventSafe() { eventSafe_ = true; }
105   /**
106    * Determines whether SSL Socket is libevent safe or not.
107    */
isLibeventSafe()108   bool isLibeventSafe() const { return eventSafe_; }
109 
110 protected:
111   /**
112    * Constructor.
113    */
114   TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<TConfiguration> config = nullptr);
115   /**
116    * Constructor with an interrupt signal.
117    */
118   TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<THRIFT_SOCKET> interruptListener,
119             std::shared_ptr<TConfiguration> config = nullptr);
120   /**
121    * Constructor, create an instance of TSSLSocket given an existing socket.
122    *
123    * @param socket An existing socket
124    */
125   TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<TConfiguration> config = nullptr);
126   /**
127    * Constructor, create an instance of TSSLSocket given an existing socket that can be interrupted.
128    *
129    * @param socket An existing socket
130    */
131   TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener,
132             std::shared_ptr<TConfiguration> config = nullptr);
133    /**
134    * Constructor.
135    *
136    * @param host  Remote host name
137    * @param port  Remote port number
138    */
139   TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port, std::shared_ptr<TConfiguration> config = nullptr);
140     /**
141   * Constructor with an interrupt signal.
142   *
143   * @param host  Remote host name
144   * @param port  Remote port number
145   */
146     TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port, std::shared_ptr<THRIFT_SOCKET> interruptListener,
147               std::shared_ptr<TConfiguration> config = nullptr);
148   /**
149    * Authorize peer access after SSL handshake completes.
150    */
151   virtual void authorize();
152   /**
153    * Initiate SSL handshake if not already initiated.
154    */
155   void initializeHandshake();
156   /**
157    * Initiate SSL handshake params.
158    */
159   void initializeHandshakeParams();
160   /**
161    * Check if  SSL handshake is completed or not.
162    */
163   bool checkHandshake();
164   /**
165    * Waits for an socket or shutdown event.
166    *
167    * @throw TTransportException::INTERRUPTED if interrupted is signaled.
168    *
169    * @return TSSL_EINTR if EINTR happened on the underlying socket
170    *         TSSL_DATA  if data is available on the socket.
171    */
172   unsigned int waitForEvent(bool wantRead);
173 
174   bool server_;
175   SSL* ssl_;
176   std::shared_ptr<SSLContext> ctx_;
177   std::shared_ptr<AccessManager> access_;
178   friend class TSSLSocketFactory;
179 
180 private:
181   bool handshakeCompleted_;
182   int readRetryCount_;
183   bool eventSafe_;
184 
185   void init();
186 };
187 
188 /**
189  * SSL socket factory. SSL sockets should be created via SSL factory.
190  * The factory will automatically initialize and cleanup openssl as long as
191  * there is a TSSLSocketFactory instantiated, and as long as the static
192  * boolean manualOpenSSLInitialization_ is set to false, the default.
193  *
194  * If you would like to initialize and cleanup openssl yourself, set
195  * manualOpenSSLInitialization_ to true and TSSLSocketFactory will no
196  * longer be responsible for openssl initialization and teardown.
197  *
198  * It is the responsibility of the code using TSSLSocketFactory to
199  * ensure that the factory lifetime exceeds the lifetime of any sockets
200  * it might create.  If this is not guaranteed, a socket may call into
201  * openssl after the socket factory has cleaned up openssl!  This
202  * guarantee is unnecessary if manualOpenSSLInitialization_ is true,
203  * however, since it would be up to the consuming application instead.
204  */
205 class TSSLSocketFactory {
206 public:
207   /**
208    * Constructor/Destructor
209    *
210    * @param protocol The SSL/TLS protocol to use.
211    */
212   TSSLSocketFactory(SSLProtocol protocol = SSLTLS);
213   virtual ~TSSLSocketFactory();
214   /**
215    * Create an instance of TSSLSocket with a fresh new socket.
216    */
217   virtual std::shared_ptr<TSSLSocket> createSocket();
218   /**
219    * Create an instance of TSSLSocket with a fresh new socket, which is interruptable.
220    */
221   virtual std::shared_ptr<TSSLSocket> createSocket(std::shared_ptr<THRIFT_SOCKET> interruptListener);
222   /**
223    * Create an instance of TSSLSocket with the given socket.
224    *
225    * @param socket An existing socket.
226    */
227   virtual std::shared_ptr<TSSLSocket> createSocket(THRIFT_SOCKET socket);
228   /**
229    * Create an instance of TSSLSocket with the given socket which is interruptable.
230    *
231    * @param socket An existing socket.
232    */
233   virtual std::shared_ptr<TSSLSocket> createSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener);
234   /**
235   * Create an instance of TSSLSocket.
236   *
237   * @param host  Remote host to be connected to
238   * @param port  Remote port to be connected to
239   */
240   virtual std::shared_ptr<TSSLSocket> createSocket(const std::string& host, int port);
241   /**
242   * Create an instance of TSSLSocket.
243   *
244   * @param host  Remote host to be connected to
245   * @param port  Remote port to be connected to
246   */
247   virtual std::shared_ptr<TSSLSocket> createSocket(const std::string& host, int port, std::shared_ptr<THRIFT_SOCKET> interruptListener);
248   /**
249    * Set ciphers to be used in SSL handshake process.
250    *
251    * @param ciphers  A list of ciphers
252    */
253   virtual void ciphers(const std::string& enable);
254   /**
255    * Enable/Disable authentication.
256    *
257    * @param required Require peer to present valid certificate if true
258    */
259   virtual void authenticate(bool required);
260   /**
261    * Load server certificate.
262    *
263    * @param path   Path to the certificate file
264    * @param format Certificate file format
265    */
266   virtual void loadCertificate(const char* path, const char* format = "PEM");
267   virtual void loadCertificateFromBuffer(const char* aCertificate, const char* format = "PEM");
268   /**
269    * Load private key.
270    *
271    * @param path   Path to the private key file
272    * @param format Private key file format
273    */
274   virtual void loadPrivateKey(const char* path, const char* format = "PEM");
275   virtual void loadPrivateKeyFromBuffer(const char* aPrivateKey, const char* format = "PEM");
276   /**
277    * Load trusted certificates from specified file.
278    *
279    * @param path Path to trusted certificate file
280    */
281   virtual void loadTrustedCertificates(const char* path, const char* capath = nullptr);
282   virtual void loadTrustedCertificatesFromBuffer(const char* aCertificate, const char* aChain = nullptr);
283   /**
284    * Default randomize method.
285    */
286   virtual void randomize();
287   /**
288    * Override default OpenSSL password callback with getPassword().
289    */
290   void overrideDefaultPasswordCallback();
291   /**
292    * Set/Unset server mode.
293    *
294    * @param flag  Server mode if true
295    */
server(bool flag)296   virtual void server(bool flag) { server_ = flag; }
297   /**
298    * Determine whether the socket is in server or client mode.
299    *
300    * @return true, if server mode, or, false, if client mode
301    */
server()302   virtual bool server() const { return server_; }
303   /**
304    * Set AccessManager.
305    *
306    * @param manager  The AccessManager instance
307    */
access(std::shared_ptr<AccessManager> manager)308   virtual void access(std::shared_ptr<AccessManager> manager) { access_ = manager; }
setManualOpenSSLInitialization(bool manualOpenSSLInitialization)309   static void setManualOpenSSLInitialization(bool manualOpenSSLInitialization) {
310     manualOpenSSLInitialization_ = manualOpenSSLInitialization;
311   }
312 
313 protected:
314   std::shared_ptr<SSLContext> ctx_;
315 
316   /**
317    * Override this method for custom password callback. It may be called
318    * multiple times at any time during a session as necessary.
319    *
320    * @param password Pass collected password to OpenSSL
321    * @param size     Maximum length of password including NULL character
322    */
getPassword(std::string &,int)323   virtual void getPassword(std::string& /* password */, int /* size */) {}
324 
325 private:
326   bool server_;
327   std::shared_ptr<AccessManager> access_;
328   static concurrency::Mutex mutex_;
329   static uint64_t count_;
330   THRIFT_EXPORT static bool manualOpenSSLInitialization_;
331   void setup(std::shared_ptr<TSSLSocket> ssl);
332   static int passwordCallback(char* password, int size, int, void* data);
333 };
334 
335 /**
336  * SSL exception.
337  */
338 class TSSLException : public TTransportException {
339 public:
TSSLException(const std::string & message)340   TSSLException(const std::string& message)
341     : TTransportException(TTransportException::INTERNAL_ERROR, message) {}
342 
what()343   const char* what() const noexcept override {
344     if (message_.empty()) {
345       return "TSSLException";
346     } else {
347       return message_.c_str();
348     }
349   }
350 };
351 
352 /**
353  * Wrap OpenSSL SSL_CTX into a class.
354  */
355 class SSLContext {
356 public:
357   SSLContext(const SSLProtocol& protocol = SSLTLS);
358   virtual ~SSLContext();
359   SSL* createSSL();
get()360   SSL_CTX* get() { return ctx_; }
361 
362 private:
363   SSL_CTX* ctx_;
364 };
365 
/**
 * Callback interface for access control, meant to verify the remote host.
 *
 * An instance is constructed when the application starts and is set on a
 * TSSLSocketFactory instance, which passes it on to all TSSLSocket instances
 * it creates.
 *
 * Every verify() overload in this base class returns DENY; subclasses
 * override the overloads they care about.
 */
class AccessManager {
public:
  /** Result of a verify() call. */
  enum Decision {
    DENY = -1, // deny access
    SKIP = 0,  // cannot make decision, move on to next (if any)
    ALLOW = 1  // allow access
  };

  /**
   * Destructor
   */
  virtual ~AccessManager() = default;

  /**
   * Decide whether the peer should be granted access. Called once after the
   * SSL handshake completes successfully, before the peer certificate is
   * examined.
   *
   * If a valid decision (ALLOW or DENY) is returned, the peer certificate is
   * not to be verified.
   *
   * @param  sa Peer IP address
   * @return ALLOW, DENY or SKIP; this base implementation always returns DENY
   */
  virtual Decision verify(const sockaddr_storage& /* sa */) noexcept {
    return Decision::DENY;
  }

  /**
   * Decide whether the peer should be granted access. Called every time a
   * DNS subjectAltName/common name is extracted from the peer's certificate.
   *
   * @param  host Client mode: host name returned by TSocket::getHost()
   *              Server mode: host name returned by TSocket::getPeerHost()
   * @param  name SubjectAltName or common name extracted from peer certificate
   * @param  size Length of name
   * @return ALLOW, DENY or SKIP; this base implementation always returns DENY
   *
   * Note: The "name" parameter may be UTF8 encoded.
   */
  virtual Decision verify(const std::string& /* host */,
                          const char* /* name */,
                          int /* size */) noexcept {
    return Decision::DENY;
  }

  /**
   * Decide whether the peer should be granted access. Called every time an
   * IP subjectAltName is extracted from the peer's certificate.
   *
   * @param  sa   Peer IP address retrieved from the underlying socket
   * @param  data IP address extracted from certificate
   * @param  size Length of the IP address
   * @return ALLOW, DENY or SKIP; this base implementation always returns DENY
   */
  virtual Decision verify(const sockaddr_storage& /* sa */,
                          const char* /* data */,
                          int /* size */) noexcept {
    return Decision::DENY;
  }
};
428 
429 typedef AccessManager::Decision Decision;
430 
/**
 * AccessManager subclass that overrides all three verify() overloads; the
 * implementations are defined out of line.
 *
 * NOTE(review): judging by the name this is the default policy for client
 * side sockets - confirm against the implementation.
 */
class DefaultClientAccessManager : public AccessManager {
public:
  // AccessManager interface
  Decision verify(const sockaddr_storage& sa) noexcept override;
  Decision verify(const std::string& host, const char* name, int size) noexcept override;
  Decision verify(const sockaddr_storage& sa, const char* data, int size) noexcept override;
};
438 }
439 }
440 }
441 
442 #endif
443