1 /*
2  * Copyright 2006 Facebook
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
7 #include <thrift/thrift-config.h>
8 
9 #include <cstring>
10 #include <errno.h>
11 #include <memory>
12 #include <string>
13 #ifdef HAVE_ARPA_INET_H
14 #include <zephyr/posix/arpa/inet.h>
15 #endif
16 #include <sys/types.h>
17 #ifdef HAVE_POLL_H
18 #include <poll.h>
19 #endif
20 
21 #include <zephyr/net/tls_credentials.h>
22 
23 #include <fcntl.h>
24 
25 #include <thrift/TToString.h>
26 #include <thrift/concurrency/Mutex.h>
27 #include <thrift/transport/PlatformSocket.h>
28 #include <thrift/transport/TSSLSocket.h>
29 #include <thrift/transport/ThriftTLScertificateType.h>
30 
31 using namespace apache::thrift::concurrency;
32 using std::string;
33 
34 struct CRYPTO_dynlock_value {
35 	Mutex mutex;
36 };
37 
38 namespace apache
39 {
40 namespace thrift
41 {
42 namespace transport
43 {
44 
45 static bool matchName(const char *host, const char *pattern, int size);
46 static char uppercase(char c);
47 
48 // TSSLSocket implementation
TSSLSocket(std::shared_ptr<SSLContext> ctx,std::shared_ptr<TConfiguration> config)49 TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<TConfiguration> config)
50 	: TSocket(config), server_(false), ctx_(ctx)
51 {
52 	init();
53 }
54 
TSSLSocket(std::shared_ptr<SSLContext> ctx,std::shared_ptr<THRIFT_SOCKET> interruptListener,std::shared_ptr<TConfiguration> config)55 TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx,
56 		       std::shared_ptr<THRIFT_SOCKET> interruptListener,
57 		       std::shared_ptr<TConfiguration> config)
58 	: TSocket(config), server_(false), ctx_(ctx)
59 {
60 	init();
61 	interruptListener_ = interruptListener;
62 }
63 
TSSLSocket(std::shared_ptr<SSLContext> ctx,THRIFT_SOCKET socket,std::shared_ptr<TConfiguration> config)64 TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket,
65 		       std::shared_ptr<TConfiguration> config)
66 	: TSocket(socket, config), server_(false), ctx_(ctx)
67 {
68 	init();
69 }
70 
TSSLSocket(std::shared_ptr<SSLContext> ctx,THRIFT_SOCKET socket,std::shared_ptr<THRIFT_SOCKET> interruptListener,std::shared_ptr<TConfiguration> config)71 TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket,
72 		       std::shared_ptr<THRIFT_SOCKET> interruptListener,
73 		       std::shared_ptr<TConfiguration> config)
74 	: TSocket(socket, interruptListener, config), server_(false), ctx_(ctx)
75 {
76 	init();
77 }
78 
TSSLSocket(std::shared_ptr<SSLContext> ctx,string host,int port,std::shared_ptr<TConfiguration> config)79 TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, string host, int port,
80 		       std::shared_ptr<TConfiguration> config)
81 	: TSocket(host, port, config), server_(false), ctx_(ctx)
82 {
83 	init();
84 }
85 
TSSLSocket(std::shared_ptr<SSLContext> ctx,string host,int port,std::shared_ptr<THRIFT_SOCKET> interruptListener,std::shared_ptr<TConfiguration> config)86 TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, string host, int port,
87 		       std::shared_ptr<THRIFT_SOCKET> interruptListener,
88 		       std::shared_ptr<TConfiguration> config)
89 	: TSocket(host, port, config), server_(false), ctx_(ctx)
90 {
91 	init();
92 	interruptListener_ = interruptListener;
93 }
94 
~TSSLSocket()95 TSSLSocket::~TSSLSocket()
96 {
97 	close();
98 }
99 
/**
 * Convert a typed pointer to the untyped pointer form expected by
 * getsockopt()-style option arguments.
 *
 * @param v Pointer to the option value.
 * @return The same address as a void pointer.
 */
template <typename T> inline void *cast_sockopt(T *v)
{
	void *const untyped = reinterpret_cast<void *>(v);
	return untyped;
}
104 
authorize()105 void TSSLSocket::authorize()
106 {
107 }
108 
openSecConnection(struct addrinfo * res)109 void TSSLSocket::openSecConnection(struct addrinfo *res)
110 {
111 	socket_ = socket(res->ai_family, res->ai_socktype, ctx_->protocol);
112 
113 	if (socket_ == THRIFT_INVALID_SOCKET) {
114 		int errno_copy = THRIFT_GET_SOCKET_ERROR;
115 		GlobalOutput.perror("TSocket::open() socket() " + getSocketInfo(), errno_copy);
116 		throw TTransportException(TTransportException::NOT_OPEN, "socket()", errno_copy);
117 	}
118 
119 	static const sec_tag_t sec_tag_list[3] = {
120 		Thrift_TLS_CA_CERT_TAG, Thrift_TLS_SERVER_CERT_TAG, Thrift_TLS_PRIVATE_KEY};
121 
122 	int ret =
123 		setsockopt(socket_, SOL_TLS, TLS_SEC_TAG_LIST, sec_tag_list, sizeof(sec_tag_list));
124 	if (ret != 0) {
125 		throw TTransportException(TTransportException::NOT_OPEN,
126 					  "set TLS_SEC_TAG_LIST failed");
127 	}
128 
129 	ret = setsockopt(socket_, SOL_TLS, TLS_PEER_VERIFY, &(ctx_->verifyMode),
130 			 sizeof(ctx_->verifyMode));
131 	if (ret != 0) {
132 		throw TTransportException(TTransportException::NOT_OPEN,
133 					  "set TLS_PEER_VERIFY failed");
134 	}
135 
136 	ret = setsockopt(socket_, SOL_TLS, TLS_HOSTNAME, host_.c_str(), host_.size());
137 	if (ret != 0) {
138 		throw TTransportException(TTransportException::NOT_OPEN, "set TLS_HOSTNAME failed");
139 	}
140 
141 	// Send timeout
142 	if (sendTimeout_ > 0) {
143 		setSendTimeout(sendTimeout_);
144 	}
145 
146 	// Recv timeout
147 	if (recvTimeout_ > 0) {
148 		setRecvTimeout(recvTimeout_);
149 	}
150 
151 	if (keepAlive_) {
152 		setKeepAlive(keepAlive_);
153 	}
154 
155 	// Linger
156 	setLinger(lingerOn_, lingerVal_);
157 
158 	// No delay
159 	setNoDelay(noDelay_);
160 
161 #ifdef SO_NOSIGPIPE
162 	{
163 		int one = 1;
164 		setsockopt(socket_, SOL_SOCKET, SO_NOSIGPIPE, &one, sizeof(one));
165 	}
166 #endif
167 
168 // Uses a low min RTO if asked to.
169 #ifdef TCP_LOW_MIN_RTO
170 	if (getUseLowMinRto()) {
171 		int one = 1;
172 		setsockopt(socket_, IPPROTO_TCP, TCP_LOW_MIN_RTO, &one, sizeof(one));
173 	}
174 #endif
175 
176 	// Set the socket to be non blocking for connect if a timeout exists
177 	int flags = THRIFT_FCNTL(socket_, THRIFT_F_GETFL, 0);
178 	if (connTimeout_ > 0) {
179 		if (-1 == THRIFT_FCNTL(socket_, THRIFT_F_SETFL, flags | THRIFT_O_NONBLOCK)) {
180 			int errno_copy = THRIFT_GET_SOCKET_ERROR;
181 			GlobalOutput.perror("TSocket::open() THRIFT_FCNTL() " + getSocketInfo(),
182 					    errno_copy);
183 			throw TTransportException(TTransportException::NOT_OPEN,
184 						  "THRIFT_FCNTL() failed", errno_copy);
185 		}
186 	} else {
187 		if (-1 == THRIFT_FCNTL(socket_, THRIFT_F_SETFL, flags & ~THRIFT_O_NONBLOCK)) {
188 			int errno_copy = THRIFT_GET_SOCKET_ERROR;
189 			GlobalOutput.perror("TSocket::open() THRIFT_FCNTL " + getSocketInfo(),
190 					    errno_copy);
191 			throw TTransportException(TTransportException::NOT_OPEN,
192 						  "THRIFT_FCNTL() failed", errno_copy);
193 		}
194 	}
195 
196 	// Connect the socket
197 
198 	ret = connect(socket_, res->ai_addr, static_cast<int>(res->ai_addrlen));
199 
200 	// success case
201 	if (ret == 0) {
202 		goto done;
203 	}
204 
205 	if ((THRIFT_GET_SOCKET_ERROR != THRIFT_EINPROGRESS) &&
206 	    (THRIFT_GET_SOCKET_ERROR != THRIFT_EWOULDBLOCK)) {
207 		int errno_copy = THRIFT_GET_SOCKET_ERROR;
208 		GlobalOutput.perror("TSocket::open() connect() " + getSocketInfo(), errno_copy);
209 		throw TTransportException(TTransportException::NOT_OPEN, "connect() failed",
210 					  errno_copy);
211 	}
212 
213 	struct THRIFT_POLLFD fds[1];
214 	std::memset(fds, 0, sizeof(fds));
215 	fds[0].fd = socket_;
216 	fds[0].events = THRIFT_POLLOUT;
217 	ret = THRIFT_POLL(fds, 1, connTimeout_);
218 
219 	if (ret > 0) {
220 		// Ensure the socket is connected and that there are no errors set
221 		int val;
222 		socklen_t lon;
223 		lon = sizeof(int);
224 		int ret2 = getsockopt(socket_, SOL_SOCKET, SO_ERROR, cast_sockopt(&val), &lon);
225 		if (ret2 == -1) {
226 			int errno_copy = THRIFT_GET_SOCKET_ERROR;
227 			GlobalOutput.perror("TSocket::open() getsockopt() " + getSocketInfo(),
228 					    errno_copy);
229 			throw TTransportException(TTransportException::NOT_OPEN, "getsockopt()",
230 						  errno_copy);
231 		}
232 		// no errors on socket, go to town
233 		if (val == 0) {
234 			goto done;
235 		}
236 		GlobalOutput.perror("TSocket::open() error on socket (after THRIFT_POLL) " +
237 					    getSocketInfo(),
238 				    val);
239 		throw TTransportException(TTransportException::NOT_OPEN, "socket open() error",
240 					  val);
241 	} else if (ret == 0) {
242 		// socket timed out
243 		string errStr = "TSocket::open() timed out " + getSocketInfo();
244 		GlobalOutput(errStr.c_str());
245 		throw TTransportException(TTransportException::NOT_OPEN, "open() timed out");
246 	} else {
247 		// error on THRIFT_POLL()
248 		int errno_copy = THRIFT_GET_SOCKET_ERROR;
249 		GlobalOutput.perror("TSocket::open() THRIFT_POLL() " + getSocketInfo(), errno_copy);
250 		throw TTransportException(TTransportException::NOT_OPEN, "THRIFT_POLL() failed",
251 					  errno_copy);
252 	}
253 
254 done:
255 	// Set socket back to normal mode (blocking)
256 	if (-1 == THRIFT_FCNTL(socket_, THRIFT_F_SETFL, flags)) {
257 		int errno_copy = THRIFT_GET_SOCKET_ERROR;
258 		GlobalOutput.perror("TSocket::open() THRIFT_FCNTL " + getSocketInfo(), errno_copy);
259 		throw TTransportException(TTransportException::NOT_OPEN, "THRIFT_FCNTL() failed",
260 					  errno_copy);
261 	}
262 
263 	setCachedAddress(res->ai_addr, static_cast<socklen_t>(res->ai_addrlen));
264 }
265 
init()266 void TSSLSocket::init()
267 {
268 	handshakeCompleted_ = false;
269 	readRetryCount_ = 0;
270 	eventSafe_ = false;
271 }
272 
/**
 * Resolve host_:port_ and connect to the first resolved address that accepts
 * a connection.
 *
 * Each resolved address is tried in turn through openSecConnection(); after a
 * failed attempt the socket is closed before the next one is tried. If the
 * last address also fails, that failure is rethrown.
 *
 * The address list returned by getaddrinfo() is owned by a scoped holder, so
 * it is released on every exit path, including exceptions other than
 * TTransportException.
 *
 * @throws TTransportException BAD_ARGS if the socket is already open, is a
 *         server-side socket, or port_ is outside 0..65535; NOT_OPEN if the
 *         host cannot be resolved or no address could be connected.
 */
void TSSLSocket::open()
{
	if (isOpen() || server()) {
		throw TTransportException(TTransportException::BAD_ARGS);
	}

	// Validate port number
	if (port_ < 0 || port_ > 0xFFFF) {
		throw TTransportException(TTransportException::BAD_ARGS,
					  "Specified port is invalid");
	}

	struct addrinfo hints;
	std::memset(&hints, 0, sizeof(hints));
	hints.ai_family = PF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;

	// The port was range-checked above, so five digits plus the terminator
	// always fit; snprintf keeps the write bounded regardless.
	char port[sizeof("65535")];
	snprintf(port, sizeof(port), "%d", port_);

	struct addrinfo *res0 = nullptr;
	int error = getaddrinfo(host_.c_str(), port, &hints, &res0);

	// Retry without AI_ADDRCONFIG when the first lookup reports no data.
	if (error == DNS_EAI_NODATA) {
		hints.ai_flags &= ~AI_ADDRCONFIG;
		error = getaddrinfo(host_.c_str(), port, &hints, &res0);
	}

	if (error) {
		string errStr = "TSocket::open() getaddrinfo() " + getSocketInfo() +
				string(THRIFT_GAI_STRERROR(error));
		GlobalOutput(errStr.c_str());
		close();
		throw TTransportException(TTransportException::NOT_OPEN,
					  "Could not resolve host for client socket.");
	}

	// Releases the resolved list when this function exits for any reason.
	struct AddrInfoRelease {
		void operator()(struct addrinfo *list) const
		{
			freeaddrinfo(list);
		}
	};
	const std::unique_ptr<struct addrinfo, AddrInfoRelease> addresses(res0);

	// Cycle through all the returned addresses until one
	// connects or push the exception up.
	for (struct addrinfo *res = addresses.get(); res != nullptr; res = res->ai_next) {
		try {
			openSecConnection(res);
			break;
		} catch (TTransportException &) {
			close();
			if (res->ai_next == nullptr) {
				throw;
			}
		}
	}
}
332 
TSSLSocketFactory(SSLProtocol protocol)333 TSSLSocketFactory::TSSLSocketFactory(SSLProtocol protocol)
334 	: ctx_(std::make_shared<SSLContext>()), server_(false)
335 {
336 	switch (protocol) {
337 	case SSLTLS:
338 		break;
339 	case TLSv1_0:
340 		break;
341 	case TLSv1_1:
342 		ctx_->protocol = IPPROTO_TLS_1_1;
343 		break;
344 	case TLSv1_2:
345 		ctx_->protocol = IPPROTO_TLS_1_2;
346 		break;
347 	default:
348 		throw TTransportException(TTransportException::BAD_ARGS,
349 					  "Specified protocol is invalid");
350 	}
351 }
352 
~TSSLSocketFactory()353 TSSLSocketFactory::~TSSLSocketFactory()
354 {
355 }
356 
createSocket()357 std::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket()
358 {
359 	std::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_));
360 	setup(ssl);
361 	return ssl;
362 }
363 
364 std::shared_ptr<TSSLSocket>
createSocket(std::shared_ptr<THRIFT_SOCKET> interruptListener)365 TSSLSocketFactory::createSocket(std::shared_ptr<THRIFT_SOCKET> interruptListener)
366 {
367 	std::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, interruptListener));
368 	setup(ssl);
369 	return ssl;
370 }
371 
createSocket(THRIFT_SOCKET socket)372 std::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(THRIFT_SOCKET socket)
373 {
374 	std::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, socket));
375 	setup(ssl);
376 	return ssl;
377 }
378 
379 std::shared_ptr<TSSLSocket>
createSocket(THRIFT_SOCKET socket,std::shared_ptr<THRIFT_SOCKET> interruptListener)380 TSSLSocketFactory::createSocket(THRIFT_SOCKET socket,
381 				std::shared_ptr<THRIFT_SOCKET> interruptListener)
382 {
383 	std::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, socket, interruptListener));
384 	setup(ssl);
385 	return ssl;
386 }
387 
createSocket(const string & host,int port)388 std::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(const string &host, int port)
389 {
390 	std::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, host, port));
391 	setup(ssl);
392 	return ssl;
393 }
394 
395 std::shared_ptr<TSSLSocket>
createSocket(const string & host,int port,std::shared_ptr<THRIFT_SOCKET> interruptListener)396 TSSLSocketFactory::createSocket(const string &host, int port,
397 				std::shared_ptr<THRIFT_SOCKET> interruptListener)
398 {
399 	std::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, host, port, interruptListener));
400 	setup(ssl);
401 	return ssl;
402 }
403 
404 static void tlsCredtErrMsg(string &errors, const int status);
405 
setup(std::shared_ptr<TSSLSocket> ssl)406 void TSSLSocketFactory::setup(std::shared_ptr<TSSLSocket> ssl)
407 {
408 	ssl->server(server());
409 	if (access_ == nullptr && !server()) {
410 		access_ = std::shared_ptr<AccessManager>(new DefaultClientAccessManager);
411 	}
412 	if (access_ != nullptr) {
413 		ssl->access(access_);
414 	}
415 }
416 
ciphers(const string & enable)417 void TSSLSocketFactory::ciphers(const string &enable)
418 {
419 }
420 
authenticate(bool required)421 void TSSLSocketFactory::authenticate(bool required)
422 {
423 	if (required) {
424 		ctx_->verifyMode = TLS_PEER_VERIFY_REQUIRED;
425 	} else {
426 		ctx_->verifyMode = TLS_PEER_VERIFY_NONE;
427 	}
428 }
429 
loadCertificate(const char * path,const char * format)430 void TSSLSocketFactory::loadCertificate(const char *path, const char *format)
431 {
432 	if (path == nullptr || format == nullptr) {
433 		throw TTransportException(
434 			TTransportException::BAD_ARGS,
435 			"loadCertificateChain: either <path> or <format> is nullptr");
436 	}
437 	if (strcmp(format, "PEM") == 0) {
438 
439 	} else {
440 		throw TSSLException("Unsupported certificate format: " + string(format));
441 	}
442 }
443 
loadCertificateFromBuffer(const char * aCertificate,const char * format)444 void TSSLSocketFactory::loadCertificateFromBuffer(const char *aCertificate, const char *format)
445 {
446 	if (aCertificate == nullptr || format == nullptr) {
447 		throw TTransportException(TTransportException::BAD_ARGS,
448 					  "loadCertificate: either <path> or <format> is nullptr");
449 	}
450 
451 	if (strcmp(format, "PEM") == 0) {
452 		const int status = tls_credential_add(Thrift_TLS_SERVER_CERT_TAG,
453 						      TLS_CREDENTIAL_SERVER_CERTIFICATE,
454 						      aCertificate, strlen(aCertificate) + 1);
455 
456 		if (status != 0) {
457 			string errors;
458 			tlsCredtErrMsg(errors, status);
459 			throw TSSLException("tls_credential_add: " + errors);
460 		}
461 	} else {
462 		throw TSSLException("Unsupported certificate format: " + string(format));
463 	}
464 }
465 
loadPrivateKey(const char * path,const char * format)466 void TSSLSocketFactory::loadPrivateKey(const char *path, const char *format)
467 {
468 	if (path == nullptr || format == nullptr) {
469 		throw TTransportException(TTransportException::BAD_ARGS,
470 					  "loadPrivateKey: either <path> or <format> is nullptr");
471 	}
472 	if (strcmp(format, "PEM") == 0) {
473 		if (0) {
474 			string errors;
475 			// tlsCredtErrMsg(errors, status);
476 			throw TSSLException("SSL_CTX_use_PrivateKey_file: " + errors);
477 		}
478 	}
479 }
480 
loadPrivateKeyFromBuffer(const char * aPrivateKey,const char * format)481 void TSSLSocketFactory::loadPrivateKeyFromBuffer(const char *aPrivateKey, const char *format)
482 {
483 	if (aPrivateKey == nullptr || format == nullptr) {
484 		throw TTransportException(TTransportException::BAD_ARGS,
485 					  "loadPrivateKey: either <path> or <format> is nullptr");
486 	}
487 	if (strcmp(format, "PEM") == 0) {
488 		const int status =
489 			tls_credential_add(Thrift_TLS_PRIVATE_KEY, TLS_CREDENTIAL_PRIVATE_KEY,
490 					   aPrivateKey, strlen(aPrivateKey) + 1);
491 
492 		if (status != 0) {
493 			string errors;
494 			tlsCredtErrMsg(errors, status);
495 			throw TSSLException("SSL_CTX_use_PrivateKey: " + errors);
496 		}
497 	} else {
498 		throw TSSLException("Unsupported certificate format: " + string(format));
499 	}
500 }
501 
loadTrustedCertificates(const char * path,const char * capath)502 void TSSLSocketFactory::loadTrustedCertificates(const char *path, const char *capath)
503 {
504 	if (path == nullptr) {
505 		throw TTransportException(TTransportException::BAD_ARGS,
506 					  "loadTrustedCertificates: <path> is nullptr");
507 	}
508 	if (0) {
509 		string errors;
510 		// tlsCredtErrMsg(errors, status);
511 		throw TSSLException("SSL_CTX_load_verify_locations: " + errors);
512 	}
513 }
514 
loadTrustedCertificatesFromBuffer(const char * aCertificate,const char * aChain)515 void TSSLSocketFactory::loadTrustedCertificatesFromBuffer(const char *aCertificate,
516 							  const char *aChain)
517 {
518 	if (aCertificate == nullptr) {
519 		throw TTransportException(TTransportException::BAD_ARGS,
520 					  "loadTrustedCertificates: aCertificate is empty");
521 	}
522 	const int status = tls_credential_add(Thrift_TLS_CA_CERT_TAG, TLS_CREDENTIAL_CA_CERTIFICATE,
523 					      aCertificate, strlen(aCertificate) + 1);
524 
525 	if (status != 0) {
526 		string errors;
527 		tlsCredtErrMsg(errors, status);
528 		throw TSSLException("X509_STORE_add_cert: " + errors);
529 	}
530 
531 	if (aChain) {
532 	}
533 }
534 
randomize()535 void TSSLSocketFactory::randomize()
536 {
537 }
538 
overrideDefaultPasswordCallback()539 void TSSLSocketFactory::overrideDefaultPasswordCallback()
540 {
541 }
542 
server(bool flag)543 void TSSLSocketFactory::server(bool flag)
544 {
545 	server_ = flag;
546 	ctx_->verifyMode = TLS_PEER_VERIFY_NONE;
547 }
548 
server() const549 bool TSSLSocketFactory::server() const
550 {
551 	return server_;
552 }
553 
passwordCallback(char * password,int size,int,void * data)554 int TSSLSocketFactory::passwordCallback(char *password, int size, int, void *data)
555 {
556 	auto *factory = (TSSLSocketFactory *)data;
557 	string userPassword;
558 	factory->getPassword(userPassword, size);
559 	int length = static_cast<int>(userPassword.size());
560 	if (length > size) {
561 		length = size;
562 	}
563 	strncpy(password, userPassword.c_str(), length);
564 	userPassword.assign(userPassword.size(), '*');
565 	return length;
566 }
567 
/**
 * Translate a tls_credential_add() status code into a readable message.
 *
 * tls_credential_add() reports failures as negative errno values (-EACCES,
 * -ENOMEM, -EEXIST). Comparing against the positive constants only, as was
 * done before, made every real failure read "Unknown error". Both signs are
 * accepted so a caller passing a positive code still gets the same text.
 *
 * @param errors Receives the message; any previous content is replaced.
 * @param status Status code returned by tls_credential_add().
 */
static void tlsCredtErrMsg(std::string &errors, const int status)
{
	// Matches the code in either its negative (Zephyr) or positive form
	// without negating status itself.
	const auto isCode = [status](const int code) {
		return status == -code || status == code;
	};

	if (isCode(EACCES)) {
		errors = "Access to the TLS credential subsystem was denied";
	} else if (isCode(ENOMEM)) {
		errors = "Not enough memory to add new TLS credential";
	} else if (isCode(EEXIST)) {
		errors = "TLS credential of specific tag and type already exists";
	} else {
		errors = "Unknown error";
	}
}
581 
582 /**
583  * Default implementation of AccessManager
584  */
verify(const sockaddr_storage & sa)585 Decision DefaultClientAccessManager::verify(const sockaddr_storage &sa) noexcept
586 {
587 	(void)sa;
588 	return SKIP;
589 }
590 
verify(const string & host,const char * name,int size)591 Decision DefaultClientAccessManager::verify(const string &host, const char *name, int size) noexcept
592 {
593 	if (host.empty() || name == nullptr || size <= 0) {
594 		return SKIP;
595 	}
596 	return (matchName(host.c_str(), name, size) ? ALLOW : SKIP);
597 }
598 
verify(const sockaddr_storage & sa,const char * data,int size)599 Decision DefaultClientAccessManager::verify(const sockaddr_storage &sa, const char *data,
600 					    int size) noexcept
601 {
602 	bool match = false;
603 	if (sa.ss_family == AF_INET && size == sizeof(in_addr)) {
604 		match = (memcmp(&((sockaddr_in *)&sa)->sin_addr, data, size) == 0);
605 	} else if (sa.ss_family == AF_INET6 && size == sizeof(in6_addr)) {
606 		match = (memcmp(&((sockaddr_in6 *)&sa)->sin6_addr, data, size) == 0);
607 	}
608 	return (match ? ALLOW : SKIP);
609 }
610 
// ASCII-only upper-casing. The C library's toupper() is avoided because of
// the Turkish locale issue, i.e., toupper('i') != toupper('I') if the locale
// is "tr_TR".
char uppercase(char c)
{
	const bool isLowerAscii = (c >= 'a') && (c <= 'z');
	return isLowerAscii ? static_cast<char>(c + ('A' - 'a')) : c;
}

/**
 * Match a host name against a certificate name pattern, ignoring ASCII case.
 * The pattern may contain the wildcard "*"; each wildcard consumes host
 * characters up to (not including) the next '.' or the end of the host, so
 * it never spans more than one domain component.
 *
 * @param  host    Host name, typically the name of the remote host
 * @param  pattern Name retrieved from certificate
 * @param  size    Size of "pattern"
 * @return True, if "host" matches "pattern". False otherwise.
 */
bool matchName(const char *host, const char *pattern, int size)
{
	int patIdx = 0;
	int hostIdx = 0;

	while (patIdx < size && host[hostIdx] != '\0') {
		if (uppercase(pattern[patIdx]) == uppercase(host[hostIdx])) {
			++patIdx;
			++hostIdx;
		} else if (pattern[patIdx] == '*') {
			// Swallow the rest of the current host component.
			for (; host[hostIdx] != '.' && host[hostIdx] != '\0'; ++hostIdx) {
			}
			++patIdx;
		} else {
			break;
		}
	}

	// Both the pattern and the host must be fully consumed.
	return patIdx == size && host[hostIdx] == '\0';
}
654 } // namespace transport
655 } // namespace thrift
656 } // namespace apache
657