1 /*
2  * Copyright 2006 Facebook
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #ifndef _THRIFT_TRANSPORT_TSSLSOCKET_H_
8 #define _THRIFT_TRANSPORT_TSSLSOCKET_H_ 1
9 
// Put this first to avoid WIN32 build failure
#include <thrift/transport/TSocket.h>

#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include <thrift/concurrency/Mutex.h>

#include <zephyr/posix/sys/socket.h>
17 
18 namespace apache
19 {
20 namespace thrift
21 {
22 namespace transport
23 {
24 
// Both types are defined further down in this header; they are declared up
// front because TSSLSocket and TSSLSocketFactory only hold them through
// std::shared_ptr.
class SSLContext;
class AccessManager;
27 
/**
 * Minimum SSL/TLS protocol version to accept.
 *
 * The numeric values are explicit and stable. Value 1 (formerly SSLv2) is
 * deliberately left unused because SSLv2 is horribly insecure.
 */
enum SSLProtocol {
	SSLTLS = 0,      // SSLv2/SSLv3-compatible handshake, but negotiates TLSv1_0 or later.
	SSLv3 = 2,       // SSLv3 only -- also horribly insecure!
	TLSv1_0 = 3,     // TLSv1_0 or later.
	TLSv1_1 = 4,     // TLSv1_1 or later.
	TLSv1_2 = 5,     // TLSv1_2 or later.
	LATEST = TLSv1_2 // Alias for the newest version listed above.
};
37 
// Return codes of TSSLSocket::waitForEvent().
#define TSSL_EINTR (0) // EINTR happened on the underlying socket.
#define TSSL_DATA  (1) // Data is available on the socket.
40 
/**
 * Initialize the OpenSSL library.
 *
 * This function, or an equivalent OpenSSL initialization, must have run
 * before any TSSLSocket is used. When TSSLSocketFactory is configured for
 * manual OpenSSL initialization, the application is responsible for calling
 * this (or otherwise initializing OpenSSL) itself.
 */
void initializeOpenSSL();

/**
 * Clean up the OpenSSL library.
 *
 * Call this once all OpenSSL functionality is finished with. When
 * TSSLSocketFactory is configured for manual OpenSSL initialization, the
 * application must call this itself, or make sure that whatever initialized
 * OpenSSL also cleans it up.
 */
void cleanupOpenSSL();
57 
58 /**
59  * OpenSSL implementation for SSL socket interface.
60  */
61 class TSSLSocket : public TSocket
62 {
63 public:
64 	~TSSLSocket() override;
65 	/**
66 	 * TTransport interface.
67 	 */
68 	void open() override;
69 	/**
70 	 * Set whether to use client or server side SSL handshake protocol.
71 	 *
72 	 * @param flag  Use server side handshake protocol if true.
73 	 */
server(bool flag)74 	void server(bool flag)
75 	{
76 		server_ = flag;
77 	}
78 	/**
79 	 * Determine whether the SSL socket is server or client mode.
80 	 */
server()81 	bool server() const
82 	{
83 		return server_;
84 	}
85 	/**
86 	 * Set AccessManager.
87 	 *
88 	 * @param manager  Instance of AccessManager
89 	 */
access(std::shared_ptr<AccessManager> manager)90 	virtual void access(std::shared_ptr<AccessManager> manager)
91 	{
92 		access_ = manager;
93 	}
94 	/**
95 	 * Set eventSafe flag if libevent is used.
96 	 */
setLibeventSafe()97 	void setLibeventSafe()
98 	{
99 		eventSafe_ = true;
100 	}
101 	/**
102 	 * Determines whether SSL Socket is libevent safe or not.
103 	 */
isLibeventSafe()104 	bool isLibeventSafe() const
105 	{
106 		return eventSafe_;
107 	}
108 
109 	void authenticate(bool required);
110 
111 protected:
112 	/**
113 	 * Constructor.
114 	 */
115 	TSSLSocket(std::shared_ptr<SSLContext> ctx,
116 		   std::shared_ptr<TConfiguration> config = nullptr);
117 	/**
118 	 * Constructor with an interrupt signal.
119 	 */
120 	TSSLSocket(std::shared_ptr<SSLContext> ctx,
121 		   std::shared_ptr<THRIFT_SOCKET> interruptListener,
122 		   std::shared_ptr<TConfiguration> config = nullptr);
123 	/**
124 	 * Constructor, create an instance of TSSLSocket given an existing socket.
125 	 *
126 	 * @param socket An existing socket
127 	 */
128 	TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket,
129 		   std::shared_ptr<TConfiguration> config = nullptr);
130 	/**
131 	 * Constructor, create an instance of TSSLSocket given an existing socket that can be
132 	 * interrupted.
133 	 *
134 	 * @param socket An existing socket
135 	 */
136 	TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket,
137 		   std::shared_ptr<THRIFT_SOCKET> interruptListener,
138 		   std::shared_ptr<TConfiguration> config = nullptr);
139 	/**
140 	 * Constructor.
141 	 *
142 	 * @param host  Remote host name
143 	 * @param port  Remote port number
144 	 */
145 	TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port,
146 		   std::shared_ptr<TConfiguration> config = nullptr);
147 	/**
148 	 * Constructor with an interrupt signal.
149 	 *
150 	 * @param host  Remote host name
151 	 * @param port  Remote port number
152 	 */
153 	TSSLSocket(std::shared_ptr<SSLContext> ctx, std::string host, int port,
154 		   std::shared_ptr<THRIFT_SOCKET> interruptListener,
155 		   std::shared_ptr<TConfiguration> config = nullptr);
156 	/**
157 	 * Authorize peer access after SSL handshake completes.
158 	 */
159 	virtual void authorize();
160 	/**
161 	 * Initiate SSL handshake if not already initiated.
162 	 */
163 	void initializeHandshake();
164 	/**
165 	 * Initiate SSL handshake params.
166 	 */
167 	void initializeHandshakeParams();
168 	/**
169 	 * Check if  SSL handshake is completed or not.
170 	 */
171 	bool checkHandshake();
172 	/**
173 	 * Waits for an socket or shutdown event.
174 	 *
175 	 * @throw TTransportException::INTERRUPTED if interrupted is signaled.
176 	 *
177 	 * @return TSSL_EINTR if EINTR happened on the underlying socket
178 	 *         TSSL_DATA  if data is available on the socket.
179 	 */
180 	unsigned int waitForEvent(bool wantRead);
181 
182 	void openSecConnection(struct addrinfo *res);
183 
184 	bool server_;
185 	std::shared_ptr<SSLContext> ctx_;
186 	std::shared_ptr<AccessManager> access_;
187 	friend class TSSLSocketFactory;
188 
189 private:
190 	bool handshakeCompleted_;
191 	int readRetryCount_;
192 	bool eventSafe_;
193 
194 	void init();
195 };
196 
197 /**
198  * SSL socket factory. SSL sockets should be created via SSL factory.
199  * The factory will automatically initialize and cleanup openssl as long as
200  * there is a TSSLSocketFactory instantiated, and as long as the static
201  * boolean manualOpenSSLInitialization_ is set to false, the default.
202  *
203  * If you would like to initialize and cleanup openssl yourself, set
204  * manualOpenSSLInitialization_ to true and TSSLSocketFactory will no
205  * longer be responsible for openssl initialization and teardown.
206  *
207  * It is the responsibility of the code using TSSLSocketFactory to
208  * ensure that the factory lifetime exceeds the lifetime of any sockets
209  * it might create.  If this is not guaranteed, a socket may call into
210  * openssl after the socket factory has cleaned up openssl!  This
211  * guarantee is unnecessary if manualOpenSSLInitialization_ is true,
212  * however, since it would be up to the consuming application instead.
213  */
214 class TSSLSocketFactory
215 {
216       public:
217 	/**
218 	 * Constructor/Destructor
219 	 *
220 	 * @param protocol The SSL/TLS protocol to use.
221 	 */
222 	TSSLSocketFactory(SSLProtocol protocol = SSLTLS);
223 	virtual ~TSSLSocketFactory();
224 	/**
225 	 * Create an instance of TSSLSocket with a fresh new socket.
226 	 */
227 	virtual std::shared_ptr<TSSLSocket> createSocket();
228 	/**
229 	 * Create an instance of TSSLSocket with a fresh new socket, which is interruptable.
230 	 */
231 	virtual std::shared_ptr<TSSLSocket>
232 	createSocket(std::shared_ptr<THRIFT_SOCKET> interruptListener);
233 	/**
234 	 * Create an instance of TSSLSocket with the given socket.
235 	 *
236 	 * @param socket An existing socket.
237 	 */
238 	virtual std::shared_ptr<TSSLSocket> createSocket(THRIFT_SOCKET socket);
239 	/**
240 	 * Create an instance of TSSLSocket with the given socket which is interruptable.
241 	 *
242 	 * @param socket An existing socket.
243 	 */
244 	virtual std::shared_ptr<TSSLSocket>
245 	createSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener);
246 	/**
247 	 * Create an instance of TSSLSocket.
248 	 *
249 	 * @param host  Remote host to be connected to
250 	 * @param port  Remote port to be connected to
251 	 */
252 	virtual std::shared_ptr<TSSLSocket> createSocket(const std::string &host, int port);
253 	/**
254 	 * Create an instance of TSSLSocket.
255 	 *
256 	 * @param host  Remote host to be connected to
257 	 * @param port  Remote port to be connected to
258 	 */
259 	virtual std::shared_ptr<TSSLSocket>
260 	createSocket(const std::string &host, int port,
261 		     std::shared_ptr<THRIFT_SOCKET> interruptListener);
262 	/**
263 	 * Set ciphers to be used in SSL handshake process.
264 	 *
265 	 * @param ciphers  A list of ciphers
266 	 */
267 	virtual void ciphers(const std::string &enable);
268 	/**
269 	 * Enable/Disable authentication.
270 	 *
271 	 * @param required Require peer to present valid certificate if true
272 	 */
273 	virtual void authenticate(bool required);
274 	/**
275 	 * Load server certificate.
276 	 *
277 	 * @param path   Path to the certificate file
278 	 * @param format Certificate file format
279 	 */
280 	virtual void loadCertificate(const char *path, const char *format = "PEM");
281 	virtual void loadCertificateFromBuffer(const char *aCertificate,
282 					       const char *format = "PEM");
283 	/**
284 	 * Load private key.
285 	 *
286 	 * @param path   Path to the private key file
287 	 * @param format Private key file format
288 	 */
289 	virtual void loadPrivateKey(const char *path, const char *format = "PEM");
290 	virtual void loadPrivateKeyFromBuffer(const char *aPrivateKey, const char *format = "PEM");
291 	/**
292 	 * Load trusted certificates from specified file.
293 	 *
294 	 * @param path Path to trusted certificate file
295 	 */
296 	virtual void loadTrustedCertificates(const char *path, const char *capath = nullptr);
297 	virtual void loadTrustedCertificatesFromBuffer(const char *aCertificate,
298 						       const char *aChain = nullptr);
299 	/**
300 	 * Default randomize method.
301 	 */
302 	virtual void randomize();
303 	/**
304 	 * Override default OpenSSL password callback with getPassword().
305 	 */
306 	void overrideDefaultPasswordCallback();
307 	/**
308 	 * Set/Unset server mode.
309 	 *
310 	 * @param flag  Server mode if true
311 	 */
312 	virtual void server(bool flag);
313 	/**
314 	 * Determine whether the socket is in server or client mode.
315 	 *
316 	 * @return true, if server mode, or, false, if client mode
317 	 */
318 	virtual bool server() const;
319 	/**
320 	 * Set AccessManager.
321 	 *
322 	 * @param manager  The AccessManager instance
323 	 */
access(std::shared_ptr<AccessManager> manager)324 	virtual void access(std::shared_ptr<AccessManager> manager)
325 	{
326 		access_ = manager;
327 	}
setManualOpenSSLInitialization(bool manualOpenSSLInitialization)328 	static void setManualOpenSSLInitialization(bool manualOpenSSLInitialization)
329 	{
330 		manualOpenSSLInitialization_ = manualOpenSSLInitialization;
331 	}
332 
333       protected:
334 	std::shared_ptr<SSLContext> ctx_;
335 
336 	/**
337 	 * Override this method for custom password callback. It may be called
338 	 * multiple times at any time during a session as necessary.
339 	 *
340 	 * @param password Pass collected password to OpenSSL
341 	 * @param size     Maximum length of password including NULL character
342 	 */
getPassword(std::string &,int)343 	virtual void getPassword(std::string & /* password */, int /* size */)
344 	{
345 	}
346 
347       private:
348 	bool server_;
349 	std::shared_ptr<AccessManager> access_;
350 	static concurrency::Mutex mutex_;
351 	static uint64_t count_;
352 	THRIFT_EXPORT static bool manualOpenSSLInitialization_;
353 
354 	void setup(std::shared_ptr<TSSLSocket> ssl);
355 	static int passwordCallback(char *password, int size, int, void *data);
356 };
357 
358 /**
359  * SSL exception.
360  */
361 class TSSLException : public TTransportException
362 {
363       public:
TSSLException(const std::string & message)364 	TSSLException(const std::string &message)
365 		: TTransportException(TTransportException::INTERNAL_ERROR, message)
366 	{
367 	}
368 
what()369 	const char *what() const noexcept override
370 	{
371 		if (message_.empty()) {
372 			return "TSSLException";
373 		} else {
374 			return message_.c_str();
375 		}
376 	}
377 };
378 
379 struct SSLContext {
380 	int verifyMode = TLS_PEER_VERIFY_REQUIRED;
381 	net_ip_protocol_secure protocol = IPPROTO_TLS_1_0;
382 };
383 
384 /**
385  * Callback interface for access control. It's meant to verify the remote host.
386  * It's constructed when application starts and set to TSSLSocketFactory
387  * instance. It's passed onto all TSSLSocket instances created by this factory
388  * object.
389  */
390 class AccessManager
391 {
392       public:
393 	enum Decision {
394 		DENY = -1, // deny access
395 		SKIP = 0,  // cannot make decision, move on to next (if any)
396 		ALLOW = 1  // allow access
397 	};
398 	/**
399 	 * Destructor
400 	 */
401 	virtual ~AccessManager() = default;
402 	/**
403 	 * Determine whether the peer should be granted access or not. It's called
404 	 * once after the SSL handshake completes successfully, before peer certificate
405 	 * is examined.
406 	 *
407 	 * If a valid decision (ALLOW or DENY) is returned, the peer certificate is
408 	 * not to be verified.
409 	 *
410 	 * @param  sa Peer IP address
411 	 * @return True if the peer is trusted, false otherwise
412 	 */
verify(const sockaddr_storage &)413 	virtual Decision verify(const sockaddr_storage & /* sa */) noexcept
414 	{
415 		return DENY;
416 	}
417 	/**
418 	 * Determine whether the peer should be granted access or not. It's called
419 	 * every time a DNS subjectAltName/common name is extracted from peer's
420 	 * certificate.
421 	 *
422 	 * @param  host Client mode: host name returned by TSocket::getHost()
423 	 *              Server mode: host name returned by TSocket::getPeerHost()
424 	 * @param  name SubjectAltName or common name extracted from peer certificate
425 	 * @param  size Length of name
426 	 * @return True if the peer is trusted, false otherwise
427 	 *
428 	 * Note: The "name" parameter may be UTF8 encoded.
429 	 */
verify(const std::string &,const char *,int)430 	virtual Decision verify(const std::string & /* host */, const char * /* name */,
431 				int /* size */) noexcept
432 	{
433 		return DENY;
434 	}
435 	/**
436 	 * Determine whether the peer should be granted access or not. It's called
437 	 * every time an IP subjectAltName is extracted from peer's certificate.
438 	 *
439 	 * @param  sa   Peer IP address retrieved from the underlying socket
440 	 * @param  data IP address extracted from certificate
441 	 * @param  size Length of the IP address
442 	 * @return True if the peer is trusted, false otherwise
443 	 */
verify(const sockaddr_storage &,const char *,int)444 	virtual Decision verify(const sockaddr_storage & /* sa */, const char * /* data */,
445 				int /* size */) noexcept
446 	{
447 		return DENY;
448 	}
449 };
450 
451 typedef AccessManager::Decision Decision;
452 
453 class DefaultClientAccessManager : public AccessManager
454 {
455       public:
456 	// AccessManager interface
457 	Decision verify(const sockaddr_storage &sa) noexcept override;
458 	Decision verify(const std::string &host, const char *name, int size) noexcept override;
459 	Decision verify(const sockaddr_storage &sa, const char *data, int size) noexcept override;
460 };
461 } // namespace transport
462 } // namespace thrift
463 } // namespace apache
464 
465 #endif
466