1 /*
2 * Copyright 2006 Facebook
3 *
4 * SPDX-License-Identifier: Apache-2.0
5 */
6
#include <cstring>
#include <mutex>
#include <utility>

#include <zephyr/net/tls_credentials.h>
#include <zephyr/posix/sys/socket.h>
#include <zephyr/posix/unistd.h>

#include <thrift/thrift_export.h>
#include <thrift/transport/TSSLServerSocket.h>
#include <thrift/transport/TSSLSocket.h>
#include <thrift/transport/ThriftTLScertificateType.h>

#include <thrift/transport/TSocketUtils.h>
18
/**
 * Convert a typed socket-option pointer into the untyped `void *` argument
 * that setsockopt() expects.
 *
 * @param opt pointer to the option value
 * @return the same address, as `void *`
 */
template <typename OptT> inline void *cast_sockopt(OptT *opt)
{
	void *raw = static_cast<void *>(opt);
	return raw;
}
23
// Deleter for the heap-allocated child-interrupt socket handle created in
// TSSLServerSocket::listen() (passed to std::shared_ptr<THRIFT_SOCKET>).
// Defined elsewhere; presumably closes *ssock and frees it — TODO confirm.
void destroyer_of_fine_sockets(THRIFT_SOCKET *ssock);
25
26 namespace apache
27 {
28 namespace thrift
29 {
30 namespace transport
31 {
32
33 /**
34 * SSL server socket implementation.
35 */
TSSLServerSocket(int port,std::shared_ptr<TSSLSocketFactory> factory)36 TSSLServerSocket::TSSLServerSocket(int port, std::shared_ptr<TSSLSocketFactory> factory)
37 : TServerSocket(port), factory_(factory)
38 {
39 factory_->server(true);
40 }
41
TSSLServerSocket(const std::string & address,int port,std::shared_ptr<TSSLSocketFactory> factory)42 TSSLServerSocket::TSSLServerSocket(const std::string &address, int port,
43 std::shared_ptr<TSSLSocketFactory> factory)
44 : TServerSocket(address, port), factory_(factory)
45 {
46 factory_->server(true);
47 }
48
TSSLServerSocket(int port,int sendTimeout,int recvTimeout,std::shared_ptr<TSSLSocketFactory> factory)49 TSSLServerSocket::TSSLServerSocket(int port, int sendTimeout, int recvTimeout,
50 std::shared_ptr<TSSLSocketFactory> factory)
51 : TServerSocket(port, sendTimeout, recvTimeout), factory_(factory)
52 {
53 factory_->server(true);
54 }
55
createSocket(THRIFT_SOCKET client)56 std::shared_ptr<TSocket> TSSLServerSocket::createSocket(THRIFT_SOCKET client)
57 {
58 if (interruptableChildren_) {
59 return factory_->createSocket(client, pChildInterruptSockReader_);
60
61 } else {
62 return factory_->createSocket(client);
63 }
64 }
65
/**
 * Create, configure and bind the TLS server socket, then start listening.
 *
 * In order:
 *  - create the socket pairs used to interrupt the server and its children;
 *  - validate port_ and resolve address_/port_ into candidate addresses;
 *  - try candidates (cycling through the list) for up to retryLimit_ retries,
 *    sleeping via THRIFT_SLEEP_SEC(retryDelay_) between attempts. Each attempt
 *    creates an IPPROTO_TLS_1_2 socket, applies the generic, TCP and TLS
 *    (sec tag list) options, and calls bind();
 *  - if port_ is 0, read back the port assigned by the stack;
 *  - invoke listenCallback_ (if set) and call ::listen() with acceptBacklog_.
 *
 * @throws TTransportException BAD_ARGS if port_ is outside [0, 0xFFFF].
 * @throws TTransportException NOT_OPEN if address resolution, socket
 *         creation, TLS configuration, bind or listen fails. close() is
 *         called before throwing in each of these cases.
 */
void TSSLServerSocket::listen()
{
	THRIFT_SOCKET sv[2];
	// Create the socket pair used to interrupt
	if (-1 == THRIFT_SOCKETPAIR(AF_LOCAL, SOCK_STREAM, 0, sv)) {
		GlobalOutput.perror("TServerSocket::listen() socketpair() interrupt",
				    THRIFT_GET_SOCKET_ERROR);
		interruptSockWriter_ = THRIFT_INVALID_SOCKET;
		interruptSockReader_ = THRIFT_INVALID_SOCKET;
	} else {
		interruptSockWriter_ = sv[1];
		interruptSockReader_ = sv[0];
	}

	// Create the socket pair used to interrupt all clients
	if (-1 == THRIFT_SOCKETPAIR(AF_LOCAL, SOCK_STREAM, 0, sv)) {
		GlobalOutput.perror("TServerSocket::listen() socketpair() childInterrupt",
				    THRIFT_GET_SOCKET_ERROR);
		childInterruptSockWriter_ = THRIFT_INVALID_SOCKET;
		pChildInterruptSockReader_.reset();
	} else {
		childInterruptSockWriter_ = sv[1];
		// Heap-allocated and reference-counted so it can be handed to accepted
		// client sockets (see createSocket()); destroyer_of_fine_sockets is
		// the deleter.
		pChildInterruptSockReader_ = std::shared_ptr<THRIFT_SOCKET>(
			new THRIFT_SOCKET(sv[0]), destroyer_of_fine_sockets);
	}

	// Validate port number
	if (port_ < 0 || port_ > 0xFFFF) {
		throw TTransportException(TTransportException::BAD_ARGS,
					  "Specified port is invalid");
	}

	// Resolve host:port strings into an iterable of struct addrinfo*
	AddressResolutionHelper resolved_addresses;
	try {
		resolved_addresses.resolve(address_, std::to_string(port_), SOCK_STREAM,
					   AI_PASSIVE | AI_V4MAPPED);

	} catch (const std::system_error &e) {
		GlobalOutput.printf("getaddrinfo() -> %d; %s", e.code().value(), e.what());
		close();
		throw TTransportException(TTransportException::NOT_OPEN,
					  "Could not resolve host for server socket.");
	}

	// we may want to try to bind more than once, since THRIFT_NO_SOCKET_CACHING doesn't
	// always seem to work. The client can configure the retry variables.
	int retries = 0;
	int errno_copy = 0;

	// -- TCP socket -- //

	auto addr_iter = AddressResolutionHelper::Iter{};

	// Via DNS or somehow else, single hostname can resolve into many addresses.
	// Results may contain perhaps a mix of IPv4 and IPv6. Here, we iterate
	// over what system gave us, picking the first address that works.
	do {
		if (!addr_iter) {
			// init + recycle over many retries
			addr_iter = resolved_addresses.iterate();
		}
		auto trybind = *addr_iter++;

		// On a retry, serverSocket_ may still hold the socket from the previous
		// attempt whose bind() failed. Close it before creating a new one.
		// Otherwise each retry leaks a descriptor. Only sockets created by this
		// loop are touched (retries > 0).
		if (retries > 0 && serverSocket_ != THRIFT_INVALID_SOCKET) {
			::THRIFT_CLOSESOCKET(serverSocket_);
			serverSocket_ = THRIFT_INVALID_SOCKET;
		}

		serverSocket_ = socket(trybind->ai_family, trybind->ai_socktype, IPPROTO_TLS_1_2);
		if (serverSocket_ == -1) {
			errno_copy = THRIFT_GET_SOCKET_ERROR;
			continue;
		}

		_setup_sockopts();
		_setup_tcp_sockopts();

		static const sec_tag_t sec_tag_list[3] = {
			Thrift_TLS_CA_CERT_TAG, Thrift_TLS_SERVER_CERT_TAG, Thrift_TLS_PRIVATE_KEY};

		int ret = setsockopt(serverSocket_, SOL_TLS, TLS_SEC_TAG_LIST, sec_tag_list,
				     sizeof(sec_tag_list));
		if (ret != 0) {
			// Match the other NOT_OPEN failure paths in this function: report
			// errno and call close() before throwing.
			errno_copy = THRIFT_GET_SOCKET_ERROR;
			GlobalOutput.perror("TSSLServerSocket::listen() setsockopt() TLS_SEC_TAG_LIST ",
					    errno_copy);
			close();
			throw TTransportException(TTransportException::NOT_OPEN,
						  "set TLS_SEC_TAG_LIST failed", errno_copy);
		}

#ifdef IPV6_V6ONLY
		if (trybind->ai_family == AF_INET6) {
			int zero = 0;
			if (-1 == setsockopt(serverSocket_, IPPROTO_IPV6, IPV6_V6ONLY,
					     cast_sockopt(&zero), sizeof(zero))) {
				GlobalOutput.perror("TServerSocket::listen() IPV6_V6ONLY ",
						    THRIFT_GET_SOCKET_ERROR);
			}
		}
#endif // #ifdef IPV6_V6ONLY

		if (0 == ::bind(serverSocket_, trybind->ai_addr,
				static_cast<int>(trybind->ai_addrlen))) {
			break;
		}
		errno_copy = THRIFT_GET_SOCKET_ERROR;

		// use short circuit evaluation here to only sleep if we need to
	} while ((retries++ < retryLimit_) && (THRIFT_SLEEP_SEC(retryDelay_) == 0));

	// retrieve bind info
	if (port_ == 0 && retries <= retryLimit_) {
		struct sockaddr_storage sa;
		socklen_t len = sizeof(sa);
		std::memset(&sa, 0, len);
		if (::getsockname(serverSocket_, reinterpret_cast<struct sockaddr *>(&sa), &len) <
		    0) {
			errno_copy = THRIFT_GET_SOCKET_ERROR;
			GlobalOutput.perror("TServerSocket::getPort() getsockname() ", errno_copy);
		} else {
			if (sa.ss_family == AF_INET6) {
				const auto *sin =
					reinterpret_cast<const struct sockaddr_in6 *>(&sa);
				port_ = ntohs(sin->sin6_port);
			} else {
				const auto *sin = reinterpret_cast<const struct sockaddr_in *>(&sa);
				port_ = ntohs(sin->sin_port);
			}
		}
	}

	// throw error if socket still wasn't created successfully
	if (serverSocket_ == THRIFT_INVALID_SOCKET) {
		GlobalOutput.perror("TServerSocket::listen() socket() ", errno_copy);
		close();
		throw TTransportException(TTransportException::NOT_OPEN,
					  "Could not create server socket.", errno_copy);
	}

	// throw an error if we failed to bind properly
	if (retries > retryLimit_) {
		char errbuf[1024];

		THRIFT_SNPRINTF(errbuf, sizeof(errbuf),
				"TServerSocket::listen() Could not bind to port %d", port_);

		GlobalOutput(errbuf);
		close();
		throw TTransportException(TTransportException::NOT_OPEN, "Could not bind",
					  errno_copy);
	}

	if (listenCallback_) {
		listenCallback_(serverSocket_);
	}

	// Call listen
	if (-1 == ::listen(serverSocket_, acceptBacklog_)) {
		errno_copy = THRIFT_GET_SOCKET_ERROR;
		GlobalOutput.perror("TServerSocket::listen() listen() ", errno_copy);
		close();
		throw TTransportException(TTransportException::NOT_OPEN, "Could not listen",
					  errno_copy);
	}

	// The socket is now listening!
	listening_ = true;
}
227
close()228 void TSSLServerSocket::close()
229 {
230 rwMutex_.lock();
231 if (pChildInterruptSockReader_ != nullptr &&
232 *pChildInterruptSockReader_ != THRIFT_INVALID_SOCKET) {
233 ::THRIFT_CLOSESOCKET(*pChildInterruptSockReader_);
234 *pChildInterruptSockReader_ = THRIFT_INVALID_SOCKET;
235 }
236
237 rwMutex_.unlock();
238
239 TServerSocket::close();
240 }
241
242 } // namespace transport
243 } // namespace thrift
244 } // namespace apache
245