1 /*
2  * Copyright 2006 Facebook
3  *
4  * SPDX-License-Identifier: Apache-2.0
5  */
6 
#include <cstring>
#include <mutex>
#include <utility>

#include <zephyr/net/tls_credentials.h>
#include <zephyr/posix/sys/socket.h>
#include <zephyr/posix/unistd.h>

#include <thrift/thrift_export.h>
#include <thrift/transport/TSSLServerSocket.h>
#include <thrift/transport/TSSLSocket.h>
#include <thrift/transport/ThriftTLScertificateType.h>

#include <thrift/transport/TSocketUtils.h>
18 
/**
 * Adapt a typed option-value pointer to the untyped pointer expected by
 * setsockopt(). The pointee is neither copied nor modified.
 */
template <typename T> inline void *cast_sockopt(T *v)
{
	void *const untyped = reinterpret_cast<void *>(v);
	return untyped;
}
23 
/*
 * Custom deleter for the heap-allocated THRIFT_SOCKET that holds the shared
 * child-interrupt reader created in TSSLServerSocket::listen(). It is only
 * declared here; presumably it closes the descriptor (if still valid) and
 * frees the THRIFT_SOCKET — TODO confirm against its definition.
 */
void destroyer_of_fine_sockets(THRIFT_SOCKET *ssock);
25 
26 namespace apache
27 {
28 namespace thrift
29 {
30 namespace transport
31 {
32 
33 /**
34  * SSL server socket implementation.
35  */
TSSLServerSocket(int port,std::shared_ptr<TSSLSocketFactory> factory)36 TSSLServerSocket::TSSLServerSocket(int port, std::shared_ptr<TSSLSocketFactory> factory)
37 	: TServerSocket(port), factory_(factory)
38 {
39 	factory_->server(true);
40 }
41 
TSSLServerSocket(const std::string & address,int port,std::shared_ptr<TSSLSocketFactory> factory)42 TSSLServerSocket::TSSLServerSocket(const std::string &address, int port,
43 				   std::shared_ptr<TSSLSocketFactory> factory)
44 	: TServerSocket(address, port), factory_(factory)
45 {
46 	factory_->server(true);
47 }
48 
TSSLServerSocket(int port,int sendTimeout,int recvTimeout,std::shared_ptr<TSSLSocketFactory> factory)49 TSSLServerSocket::TSSLServerSocket(int port, int sendTimeout, int recvTimeout,
50 				   std::shared_ptr<TSSLSocketFactory> factory)
51 	: TServerSocket(port, sendTimeout, recvTimeout), factory_(factory)
52 {
53 	factory_->server(true);
54 }
55 
createSocket(THRIFT_SOCKET client)56 std::shared_ptr<TSocket> TSSLServerSocket::createSocket(THRIFT_SOCKET client)
57 {
58 	if (interruptableChildren_) {
59 		return factory_->createSocket(client, pChildInterruptSockReader_);
60 
61 	} else {
62 		return factory_->createSocket(client);
63 	}
64 }
65 
listen()66 void TSSLServerSocket::listen()
67 {
68 	THRIFT_SOCKET sv[2];
69 	// Create the socket pair used to interrupt
70 	if (-1 == THRIFT_SOCKETPAIR(AF_LOCAL, SOCK_STREAM, 0, sv)) {
71 		GlobalOutput.perror("TServerSocket::listen() socketpair() interrupt",
72 				    THRIFT_GET_SOCKET_ERROR);
73 		interruptSockWriter_ = THRIFT_INVALID_SOCKET;
74 		interruptSockReader_ = THRIFT_INVALID_SOCKET;
75 	} else {
76 		interruptSockWriter_ = sv[1];
77 		interruptSockReader_ = sv[0];
78 	}
79 
80 	// Create the socket pair used to interrupt all clients
81 	if (-1 == THRIFT_SOCKETPAIR(AF_LOCAL, SOCK_STREAM, 0, sv)) {
82 		GlobalOutput.perror("TServerSocket::listen() socketpair() childInterrupt",
83 				    THRIFT_GET_SOCKET_ERROR);
84 		childInterruptSockWriter_ = THRIFT_INVALID_SOCKET;
85 		pChildInterruptSockReader_.reset();
86 	} else {
87 		childInterruptSockWriter_ = sv[1];
88 		pChildInterruptSockReader_ = std::shared_ptr<THRIFT_SOCKET>(
89 			new THRIFT_SOCKET(sv[0]), destroyer_of_fine_sockets);
90 	}
91 
92 	// Validate port number
93 	if (port_ < 0 || port_ > 0xFFFF) {
94 		throw TTransportException(TTransportException::BAD_ARGS,
95 					  "Specified port is invalid");
96 	}
97 
98 	// Resolve host:port strings into an iterable of struct addrinfo*
99 	AddressResolutionHelper resolved_addresses;
100 	try {
101 		resolved_addresses.resolve(address_, std::to_string(port_), SOCK_STREAM,
102 					   AI_PASSIVE | AI_V4MAPPED);
103 
104 	} catch (const std::system_error &e) {
105 		GlobalOutput.printf("getaddrinfo() -> %d; %s", e.code().value(), e.what());
106 		close();
107 		throw TTransportException(TTransportException::NOT_OPEN,
108 					  "Could not resolve host for server socket.");
109 	}
110 
111 	// we may want to try to bind more than once, since THRIFT_NO_SOCKET_CACHING doesn't
112 	// always seem to work. The client can configure the retry variables.
113 	int retries = 0;
114 	int errno_copy = 0;
115 
116 	// -- TCP socket -- //
117 
118 	auto addr_iter = AddressResolutionHelper::Iter{};
119 
120 	// Via DNS or somehow else, single hostname can resolve into many addresses.
121 	// Results may contain perhaps a mix of IPv4 and IPv6.  Here, we iterate
122 	// over what system gave us, picking the first address that works.
123 	do {
124 		if (!addr_iter) {
125 			// init + recycle over many retries
126 			addr_iter = resolved_addresses.iterate();
127 		}
128 		auto trybind = *addr_iter++;
129 
130 		serverSocket_ = socket(trybind->ai_family, trybind->ai_socktype, IPPROTO_TLS_1_2);
131 		if (serverSocket_ == -1) {
132 			errno_copy = THRIFT_GET_SOCKET_ERROR;
133 			continue;
134 		}
135 
136 		_setup_sockopts();
137 		_setup_tcp_sockopts();
138 
139 		static const sec_tag_t sec_tag_list[3] = {
140 			Thrift_TLS_CA_CERT_TAG, Thrift_TLS_SERVER_CERT_TAG, Thrift_TLS_PRIVATE_KEY};
141 
142 		int ret = setsockopt(serverSocket_, SOL_TLS, TLS_SEC_TAG_LIST, sec_tag_list,
143 				     sizeof(sec_tag_list));
144 		if (ret != 0) {
145 			throw TTransportException(TTransportException::NOT_OPEN,
146 						  "set TLS_SEC_TAG_LIST failed");
147 		}
148 
149 #ifdef IPV6_V6ONLY
150 		if (trybind->ai_family == AF_INET6) {
151 			int zero = 0;
152 			if (-1 == setsockopt(serverSocket_, IPPROTO_IPV6, IPV6_V6ONLY,
153 					     cast_sockopt(&zero), sizeof(zero))) {
154 				GlobalOutput.perror("TServerSocket::listen() IPV6_V6ONLY ",
155 						    THRIFT_GET_SOCKET_ERROR);
156 			}
157 		}
158 #endif // #ifdef IPV6_V6ONLY
159 
160 		if (0 == ::bind(serverSocket_, trybind->ai_addr,
161 				static_cast<int>(trybind->ai_addrlen))) {
162 			break;
163 		}
164 		errno_copy = THRIFT_GET_SOCKET_ERROR;
165 
166 		// use short circuit evaluation here to only sleep if we need to
167 	} while ((retries++ < retryLimit_) && (THRIFT_SLEEP_SEC(retryDelay_) == 0));
168 
169 	// retrieve bind info
170 	if (port_ == 0 && retries <= retryLimit_) {
171 		struct sockaddr_storage sa;
172 		socklen_t len = sizeof(sa);
173 		std::memset(&sa, 0, len);
174 		if (::getsockname(serverSocket_, reinterpret_cast<struct sockaddr *>(&sa), &len) <
175 		    0) {
176 			errno_copy = THRIFT_GET_SOCKET_ERROR;
177 			GlobalOutput.perror("TServerSocket::getPort() getsockname() ", errno_copy);
178 		} else {
179 			if (sa.ss_family == AF_INET6) {
180 				const auto *sin =
181 					reinterpret_cast<const struct sockaddr_in6 *>(&sa);
182 				port_ = ntohs(sin->sin6_port);
183 			} else {
184 				const auto *sin = reinterpret_cast<const struct sockaddr_in *>(&sa);
185 				port_ = ntohs(sin->sin_port);
186 			}
187 		}
188 	}
189 
190 	// throw error if socket still wasn't created successfully
191 	if (serverSocket_ == THRIFT_INVALID_SOCKET) {
192 		GlobalOutput.perror("TServerSocket::listen() socket() ", errno_copy);
193 		close();
194 		throw TTransportException(TTransportException::NOT_OPEN,
195 					  "Could not create server socket.", errno_copy);
196 	}
197 
198 	// throw an error if we failed to bind properly
199 	if (retries > retryLimit_) {
200 		char errbuf[1024];
201 
202 		THRIFT_SNPRINTF(errbuf, sizeof(errbuf),
203 				"TServerSocket::listen() Could not bind to port %d", port_);
204 
205 		GlobalOutput(errbuf);
206 		close();
207 		throw TTransportException(TTransportException::NOT_OPEN, "Could not bind",
208 					  errno_copy);
209 	}
210 
211 	if (listenCallback_) {
212 		listenCallback_(serverSocket_);
213 	}
214 
215 	// Call listen
216 	if (-1 == ::listen(serverSocket_, acceptBacklog_)) {
217 		errno_copy = THRIFT_GET_SOCKET_ERROR;
218 		GlobalOutput.perror("TServerSocket::listen() listen() ", errno_copy);
219 		close();
220 		throw TTransportException(TTransportException::NOT_OPEN, "Could not listen",
221 					  errno_copy);
222 	}
223 
224 	// The socket is now listening!
225 	listening_ = true;
226 }
227 
close()228 void TSSLServerSocket::close()
229 {
230 	rwMutex_.lock();
231 	if (pChildInterruptSockReader_ != nullptr &&
232 	    *pChildInterruptSockReader_ != THRIFT_INVALID_SOCKET) {
233 		::THRIFT_CLOSESOCKET(*pChildInterruptSockReader_);
234 		*pChildInterruptSockReader_ = THRIFT_INVALID_SOCKET;
235 	}
236 
237 	rwMutex_.unlock();
238 
239 	TServerSocket::close();
240 }
241 
242 } // namespace transport
243 } // namespace thrift
244 } // namespace apache
245