1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #include <thrift/thrift-config.h>
21 
#include <cstring>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
24 #ifdef HAVE_SYS_IOCTL_H
25 #include <sys/ioctl.h>
26 #ifdef __sun
27 #include <sys/filio.h>
28 #endif // __sun
29 #endif
30 #ifdef HAVE_SYS_SOCKET_H
31 #include <sys/socket.h>
32 #endif
33 #ifdef HAVE_SYS_UN_H
34 #include <sys/un.h>
35 #endif
36 #ifdef HAVE_POLL_H
37 #include <poll.h>
38 #endif
39 #ifdef HAVE_SYS_POLL_H
40 #include <sys/poll.h>
41 #endif
42 #include <sys/types.h>
43 #ifdef HAVE_NETINET_IN_H
44 #include <netinet/in.h>
45 #include <netinet/tcp.h>
46 #endif
47 #ifdef HAVE_UNISTD_H
48 #include <unistd.h>
49 #endif
50 #include <fcntl.h>
51 
52 #include <thrift/concurrency/Monitor.h>
53 #include <thrift/transport/TSocket.h>
54 #include <thrift/transport/TTransportException.h>
55 #include <thrift/transport/PlatformSocket.h>
56 #include <thrift/transport/SocketCommon.h>
57 
#ifndef SOCKOPT_CAST_T
#if defined(_WIN32)
#define SOCKOPT_CAST_T char
#else
#define SOCKOPT_CAST_T void
#endif // defined(_WIN32)
#endif // SOCKOPT_CAST_T

// The socket calls in this file take their buffer/option pointer as `void*`
// on POSIX and as `char*` on Winsock (see SOCKOPT_CAST_T above). The helpers
// below reinterpret an arbitrary pointer as whichever type the platform wants.

/** Reinterprets a read-only pointer as the platform's const socket-buffer pointer type. */
template <typename Value>
inline auto const_cast_sockopt(const Value* ptr) -> const SOCKOPT_CAST_T* {
  return reinterpret_cast<const SOCKOPT_CAST_T*>(ptr);
}

/** Reinterprets a mutable pointer as the platform's socket-buffer pointer type. */
template <typename Value>
inline auto cast_sockopt(Value* ptr) -> SOCKOPT_CAST_T* {
  return reinterpret_cast<SOCKOPT_CAST_T*>(ptr);
}
75 
76 using std::string;
77 
78 namespace apache {
79 namespace thrift {
80 namespace transport {
81 
82 /**
83  * TSocket implementation.
84  *
85  */
86 
TSocket(const string & host,int port,std::shared_ptr<TConfiguration> config)87 TSocket::TSocket(const string& host, int port, std::shared_ptr<TConfiguration> config)
88   : TVirtualTransport(config),
89     host_(host),
90     port_(port),
91     socket_(THRIFT_INVALID_SOCKET),
92     peerPort_(0),
93     connTimeout_(0),
94     sendTimeout_(0),
95     recvTimeout_(0),
96     keepAlive_(false),
97     lingerOn_(1),
98     lingerVal_(0),
99     noDelay_(1),
100     maxRecvRetries_(5) {
101 }
102 
TSocket(const string & path,std::shared_ptr<TConfiguration> config)103 TSocket::TSocket(const string& path, std::shared_ptr<TConfiguration> config)
104   : TVirtualTransport(config),
105     port_(0),
106     path_(path),
107     socket_(THRIFT_INVALID_SOCKET),
108     peerPort_(0),
109     connTimeout_(0),
110     sendTimeout_(0),
111     recvTimeout_(0),
112     keepAlive_(false),
113     lingerOn_(1),
114     lingerVal_(0),
115     noDelay_(1),
116     maxRecvRetries_(5) {
117   cachedPeerAddr_.ipv4.sin_family = AF_UNSPEC;
118 }
119 
TSocket(std::shared_ptr<TConfiguration> config)120 TSocket::TSocket(std::shared_ptr<TConfiguration> config)
121   : TVirtualTransport(config),
122     port_(0),
123     socket_(THRIFT_INVALID_SOCKET),
124     peerPort_(0),
125     connTimeout_(0),
126     sendTimeout_(0),
127     recvTimeout_(0),
128     keepAlive_(false),
129     lingerOn_(1),
130     lingerVal_(0),
131     noDelay_(1),
132     maxRecvRetries_(5) {
133   cachedPeerAddr_.ipv4.sin_family = AF_UNSPEC;
134 }
135 
TSocket(THRIFT_SOCKET socket,std::shared_ptr<TConfiguration> config)136 TSocket::TSocket(THRIFT_SOCKET socket, std::shared_ptr<TConfiguration> config)
137   : TVirtualTransport(config),
138     port_(0),
139     socket_(socket),
140     peerPort_(0),
141     connTimeout_(0),
142     sendTimeout_(0),
143     recvTimeout_(0),
144     keepAlive_(false),
145     lingerOn_(1),
146     lingerVal_(0),
147     noDelay_(1),
148     maxRecvRetries_(5) {
149   cachedPeerAddr_.ipv4.sin_family = AF_UNSPEC;
150 #ifdef SO_NOSIGPIPE
151   {
152     int one = 1;
153     setsockopt(socket_, SOL_SOCKET, SO_NOSIGPIPE, &one, sizeof(one));
154   }
155 #endif
156 }
157 
TSocket(THRIFT_SOCKET socket,std::shared_ptr<THRIFT_SOCKET> interruptListener,std::shared_ptr<TConfiguration> config)158 TSocket::TSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener,
159                 std::shared_ptr<TConfiguration> config)
160   : TVirtualTransport(config),
161     port_(0),
162     socket_(socket),
163     peerPort_(0),
164     interruptListener_(interruptListener),
165     connTimeout_(0),
166     sendTimeout_(0),
167     recvTimeout_(0),
168     keepAlive_(false),
169     lingerOn_(1),
170     lingerVal_(0),
171     noDelay_(1),
172     maxRecvRetries_(5) {
173   cachedPeerAddr_.ipv4.sin_family = AF_UNSPEC;
174 #ifdef SO_NOSIGPIPE
175   {
176     int one = 1;
177     setsockopt(socket_, SOL_SOCKET, SO_NOSIGPIPE, &one, sizeof(one));
178   }
179 #endif
180 }
181 
~TSocket()182 TSocket::~TSocket() {
183   close();
184 }
185 
hasPendingDataToRead()186 bool TSocket::hasPendingDataToRead() {
187   if (!isOpen()) {
188     return false;
189   }
190 
191   int32_t retries = 0;
192   THRIFT_IOCTL_SOCKET_NUM_BYTES_TYPE numBytesAvailable;
193 try_again:
194   int r = THRIFT_IOCTL_SOCKET(socket_, FIONREAD, &numBytesAvailable);
195   if (r == -1) {
196     int errno_copy = THRIFT_GET_SOCKET_ERROR;
197     if (errno_copy == THRIFT_EINTR && (retries++ < maxRecvRetries_)) {
198       goto try_again;
199     }
200     GlobalOutput.perror("TSocket::hasPendingDataToRead() THRIFT_IOCTL_SOCKET() " + getSocketInfo(), errno_copy);
201     throw TTransportException(TTransportException::UNKNOWN, "Unknown", errno_copy);
202   }
203   return numBytesAvailable > 0;
204 }
205 
isOpen() const206 bool TSocket::isOpen() const {
207   return (socket_ != THRIFT_INVALID_SOCKET);
208 }
209 
peek()210 bool TSocket::peek() {
211   if (!isOpen()) {
212     return false;
213   }
214   if (interruptListener_) {
215     for (int retries = 0;;) {
216       struct THRIFT_POLLFD fds[2];
217       std::memset(fds, 0, sizeof(fds));
218       fds[0].fd = socket_;
219       fds[0].events = THRIFT_POLLIN;
220       fds[1].fd = *(interruptListener_.get());
221       fds[1].events = THRIFT_POLLIN;
222       int ret = THRIFT_POLL(fds, 2, (recvTimeout_ == 0) ? -1 : recvTimeout_);
223       int errno_copy = THRIFT_GET_SOCKET_ERROR;
224       if (ret < 0) {
225         // error cases
226         if (errno_copy == THRIFT_EINTR && (retries++ < maxRecvRetries_)) {
227           continue;
228         }
229         GlobalOutput.perror("TSocket::peek() THRIFT_POLL() ", errno_copy);
230         throw TTransportException(TTransportException::UNKNOWN, "Unknown", errno_copy);
231       } else if (ret > 0) {
232         // Check the interruptListener
233         if (fds[1].revents & THRIFT_POLLIN) {
234           return false;
235         }
236         // There must be data or a disconnection, fall through to the PEEK
237         break;
238       } else {
239         // timeout
240         return false;
241       }
242     }
243   }
244 
245   // Check to see if data is available or if the remote side closed
246   uint8_t buf;
247   int r = static_cast<int>(recv(socket_, cast_sockopt(&buf), 1, MSG_PEEK));
248   if (r == -1) {
249     int errno_copy = THRIFT_GET_SOCKET_ERROR;
250 #if defined __FreeBSD__ || defined __MACH__
251     /* shigin:
252      * freebsd returns -1 and THRIFT_ECONNRESET if socket was closed by
253      * the other side
254      */
255     if (errno_copy == THRIFT_ECONNRESET) {
256       return false;
257     }
258 #endif
259     GlobalOutput.perror("TSocket::peek() recv() " + getSocketInfo(), errno_copy);
260     throw TTransportException(TTransportException::UNKNOWN, "recv()", errno_copy);
261   }
262   return (r > 0);
263 }
264 
openConnection(struct addrinfo * res)265 void TSocket::openConnection(struct addrinfo* res) {
266 
267   if (isOpen()) {
268     return;
269   }
270 
271   if (isUnixDomainSocket()) {
272     socket_ = socket(PF_UNIX, SOCK_STREAM, IPPROTO_IP);
273   } else {
274     socket_ = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
275   }
276 
277   if (socket_ == THRIFT_INVALID_SOCKET) {
278     int errno_copy = THRIFT_GET_SOCKET_ERROR;
279     GlobalOutput.perror("TSocket::open() socket() " + getSocketInfo(), errno_copy);
280     throw TTransportException(TTransportException::NOT_OPEN, "socket()", errno_copy);
281   }
282 
283   // Send timeout
284   if (sendTimeout_ > 0) {
285     setSendTimeout(sendTimeout_);
286   }
287 
288   // Recv timeout
289   if (recvTimeout_ > 0) {
290     setRecvTimeout(recvTimeout_);
291   }
292 
293   if (keepAlive_) {
294     setKeepAlive(keepAlive_);
295   }
296 
297   // Linger
298   setLinger(lingerOn_, lingerVal_);
299 
300   // No delay
301   setNoDelay(noDelay_);
302 
303 #ifdef SO_NOSIGPIPE
304   {
305     int one = 1;
306     setsockopt(socket_, SOL_SOCKET, SO_NOSIGPIPE, &one, sizeof(one));
307   }
308 #endif
309 
310 // Uses a low min RTO if asked to.
311 #ifdef TCP_LOW_MIN_RTO
312   if (getUseLowMinRto()) {
313     int one = 1;
314     setsockopt(socket_, IPPROTO_TCP, TCP_LOW_MIN_RTO, &one, sizeof(one));
315   }
316 #endif
317 
318   // Set the socket to be non blocking for connect if a timeout exists
319   int flags = THRIFT_FCNTL(socket_, THRIFT_F_GETFL, 0);
320   if (connTimeout_ > 0) {
321     if (-1 == THRIFT_FCNTL(socket_, THRIFT_F_SETFL, flags | THRIFT_O_NONBLOCK)) {
322       int errno_copy = THRIFT_GET_SOCKET_ERROR;
323       GlobalOutput.perror("TSocket::open() THRIFT_FCNTL() " + getSocketInfo(), errno_copy);
324       throw TTransportException(TTransportException::NOT_OPEN, "THRIFT_FCNTL() failed", errno_copy);
325     }
326   } else {
327     if (-1 == THRIFT_FCNTL(socket_, THRIFT_F_SETFL, flags & ~THRIFT_O_NONBLOCK)) {
328       int errno_copy = THRIFT_GET_SOCKET_ERROR;
329       GlobalOutput.perror("TSocket::open() THRIFT_FCNTL " + getSocketInfo(), errno_copy);
330       throw TTransportException(TTransportException::NOT_OPEN, "THRIFT_FCNTL() failed", errno_copy);
331     }
332   }
333 
334   // Connect the socket
335   int ret;
336   if (isUnixDomainSocket()) {
337     // Windows supports Unix domain sockets since it ships the header
338     // HAVE_AF_UNIX_H (see https://devblogs.microsoft.com/commandline/af_unix-comes-to-windows/)
339 #if (!defined(_WIN32) || defined(HAVE_AF_UNIX_H))
340     struct sockaddr_un address;
341     socklen_t structlen = fillUnixSocketAddr(address, path_);
342 
343     ret = connect(socket_, (struct sockaddr*)&address, structlen);
344 #else
345     GlobalOutput.perror("TSocket::open() Unix Domain socket path not supported on this version of Windows", -99);
346     throw TTransportException(TTransportException::NOT_OPEN,
347                               " Unix Domain socket path not supported");
348 #endif
349   } else {
350     ret = connect(socket_, res->ai_addr, static_cast<int>(res->ai_addrlen));
351   }
352 
353   // success case
354   if (ret == 0) {
355     goto done;
356   }
357 
358   if ((THRIFT_GET_SOCKET_ERROR != THRIFT_EINPROGRESS)
359       && (THRIFT_GET_SOCKET_ERROR != THRIFT_EWOULDBLOCK)) {
360     int errno_copy = THRIFT_GET_SOCKET_ERROR;
361     GlobalOutput.perror("TSocket::open() connect() " + getSocketInfo(), errno_copy);
362     throw TTransportException(TTransportException::NOT_OPEN, "connect() failed", errno_copy);
363   }
364 
365   struct THRIFT_POLLFD fds[1];
366   std::memset(fds, 0, sizeof(fds));
367   fds[0].fd = socket_;
368   fds[0].events = THRIFT_POLLOUT;
369   ret = THRIFT_POLL(fds, 1, connTimeout_);
370 
371   if (ret > 0) {
372     // Ensure the socket is connected and that there are no errors set
373     int val;
374     socklen_t lon;
375     lon = sizeof(int);
376     int ret2 = getsockopt(socket_, SOL_SOCKET, SO_ERROR, cast_sockopt(&val), &lon);
377     if (ret2 == -1) {
378       int errno_copy = THRIFT_GET_SOCKET_ERROR;
379       GlobalOutput.perror("TSocket::open() getsockopt() " + getSocketInfo(), errno_copy);
380       throw TTransportException(TTransportException::NOT_OPEN, "getsockopt()", errno_copy);
381     }
382     // no errors on socket, go to town
383     if (val == 0) {
384       goto done;
385     }
386     GlobalOutput.perror("TSocket::open() error on socket (after THRIFT_POLL) " + getSocketInfo(),
387                         val);
388     throw TTransportException(TTransportException::NOT_OPEN, "socket open() error", val);
389   } else if (ret == 0) {
390     // socket timed out
391     string errStr = "TSocket::open() timed out " + getSocketInfo();
392     GlobalOutput(errStr.c_str());
393     throw TTransportException(TTransportException::NOT_OPEN, "open() timed out");
394   } else {
395     // error on THRIFT_POLL()
396     int errno_copy = THRIFT_GET_SOCKET_ERROR;
397     GlobalOutput.perror("TSocket::open() THRIFT_POLL() " + getSocketInfo(), errno_copy);
398     throw TTransportException(TTransportException::NOT_OPEN, "THRIFT_POLL() failed", errno_copy);
399   }
400 
401 done:
402   // Set socket back to normal mode (blocking)
403   if (-1 == THRIFT_FCNTL(socket_, THRIFT_F_SETFL, flags)) {
404     int errno_copy = THRIFT_GET_SOCKET_ERROR;
405     GlobalOutput.perror("TSocket::open() THRIFT_FCNTL " + getSocketInfo(), errno_copy);
406     throw TTransportException(TTransportException::NOT_OPEN, "THRIFT_FCNTL() failed", errno_copy);
407   }
408 
409   if (!isUnixDomainSocket()) {
410     setCachedAddress(res->ai_addr, static_cast<socklen_t>(res->ai_addrlen));
411   }
412 }
413 
open()414 void TSocket::open() {
415   if (isOpen()) {
416     return;
417   }
418   if (isUnixDomainSocket()) {
419     unix_open();
420   } else {
421     local_open();
422   }
423 }
424 
unix_open()425 void TSocket::unix_open() {
426   if (isUnixDomainSocket()) {
427     // Unix Domain Socket does not need addrinfo struct, so we pass NULL
428     openConnection(nullptr);
429   }
430 }
431 
/**
 * Resolves host_:port_ with getaddrinfo() and tries each returned address
 * in turn until one connects.
 *
 * @throws TTransportException(BAD_ARGS) if port_ is outside 0..65535.
 * @throws TTransportException(NOT_OPEN) if resolution fails, or rethrows the
 *         last address's connection error if every candidate fails. The
 *         socket is closed after every failed candidate.
 */
void TSocket::local_open() {

#ifdef _WIN32
  TWinsockSingleton::create();
#endif // _WIN32

  if (isOpen()) {
    return;
  }

  // Validate port number
  if (port_ < 0 || port_ > 0xFFFF) {
    throw TTransportException(TTransportException::BAD_ARGS, "Specified port is invalid");
  }

  struct addrinfo hints;
  std::memset(&hints, 0, sizeof(hints));
  hints.ai_family = PF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
  const std::string port = std::to_string(port_);

  struct addrinfo* res0 = nullptr;
  int error = getaddrinfo(host_.c_str(), port.c_str(), &hints, &res0);

  // Retry without AI_ADDRCONFIG when resolution reported "no data".
  if (
#ifdef _WIN32
      error == WSANO_DATA
#else
      error == EAI_NODATA
#endif
    ) {
    hints.ai_flags &= ~AI_ADDRCONFIG;
    error = getaddrinfo(host_.c_str(), port.c_str(), &hints, &res0);
  }

  if (error) {
    string errStr = "TSocket::open() getaddrinfo() " + getSocketInfo()
                    + string(THRIFT_GAI_STRERROR(error));
    GlobalOutput(errStr.c_str());
    close();
    throw TTransportException(TTransportException::NOT_OPEN,
                              "Could not resolve host for client socket.");
  }

  // Own the result list so it is released on every exit path, including
  // exceptions other than TTransportException escaping openConnection().
  auto freeAddrList = [](struct addrinfo* list) { freeaddrinfo(list); };
  std::unique_ptr<struct addrinfo, decltype(freeAddrList)> addrList(res0, freeAddrList);

  // Cycle through all the returned addresses until one
  // connects or push the exception up.
  for (struct addrinfo* res = addrList.get(); res != nullptr; res = res->ai_next) {
    try {
      openConnection(res);
      return;
    } catch (TTransportException&) {
      close();
      if (res->ai_next == nullptr) {
        throw; // last candidate: report its failure
      }
    }
  }
}
500 
close()501 void TSocket::close() {
502   if (socket_ != THRIFT_INVALID_SOCKET) {
503     shutdown(socket_, THRIFT_SHUT_RDWR);
504     ::THRIFT_CLOSESOCKET(socket_);
505   }
506   socket_ = THRIFT_INVALID_SOCKET;
507 }
508 
setSocketFD(THRIFT_SOCKET socket)509 void TSocket::setSocketFD(THRIFT_SOCKET socket) {
510   if (socket_ != THRIFT_INVALID_SOCKET) {
511     close();
512   }
513   socket_ = socket;
514 }
515 
read(uint8_t * buf,uint32_t len)516 uint32_t TSocket::read(uint8_t* buf, uint32_t len) {
517   checkReadBytesAvailable(len);
518   if (socket_ == THRIFT_INVALID_SOCKET) {
519     throw TTransportException(TTransportException::NOT_OPEN, "Called read on non-open socket");
520   }
521 
522   int32_t retries = 0;
523 
524   // THRIFT_EAGAIN can be signalled both when a timeout has occurred and when
525   // the system is out of resources (an awesome undocumented feature).
526   // The following is an approximation of the time interval under which
527   // THRIFT_EAGAIN is taken to indicate an out of resources error.
528   uint32_t eagainThresholdMicros = 0;
529   if (recvTimeout_) {
530     // if a readTimeout is specified along with a max number of recv retries, then
531     // the threshold will ensure that the read timeout is not exceeded even in the
532     // case of resource errors
533     eagainThresholdMicros = (recvTimeout_ * 1000) / ((maxRecvRetries_ > 0) ? maxRecvRetries_ : 2);
534   }
535 
536 try_again:
537   // Read from the socket
538   struct timeval begin;
539   if (recvTimeout_ > 0) {
540     THRIFT_GETTIMEOFDAY(&begin, nullptr);
541   } else {
542     // if there is no read timeout we don't need the TOD to determine whether
543     // an THRIFT_EAGAIN is due to a timeout or an out-of-resource condition.
544     begin.tv_sec = begin.tv_usec = 0;
545   }
546 
547   int got = 0;
548 
549   if (interruptListener_) {
550     struct THRIFT_POLLFD fds[2];
551     std::memset(fds, 0, sizeof(fds));
552     fds[0].fd = socket_;
553     fds[0].events = THRIFT_POLLIN;
554     fds[1].fd = *(interruptListener_.get());
555     fds[1].events = THRIFT_POLLIN;
556 
557     int ret = THRIFT_POLL(fds, 2, (recvTimeout_ == 0) ? -1 : recvTimeout_);
558     int errno_copy = THRIFT_GET_SOCKET_ERROR;
559     if (ret < 0) {
560       // error cases
561       if (errno_copy == THRIFT_EINTR && (retries++ < maxRecvRetries_)) {
562         goto try_again;
563       }
564       GlobalOutput.perror("TSocket::read() THRIFT_POLL() ", errno_copy);
565       throw TTransportException(TTransportException::UNKNOWN, "Unknown", errno_copy);
566     } else if (ret > 0) {
567       // Check the interruptListener
568       if (fds[1].revents & THRIFT_POLLIN) {
569         throw TTransportException(TTransportException::INTERRUPTED, "Interrupted");
570       }
571     } else /* ret == 0 */ {
572       GlobalOutput.perror("TSocket::read() THRIFT_EAGAIN (timed out) after %f ms", recvTimeout_);
573       throw TTransportException(TTransportException::TIMED_OUT, "THRIFT_EAGAIN (timed out)");
574     }
575 
576     // falling through means there is something to recv and it cannot block
577   }
578 
579   got = static_cast<int>(recv(socket_, cast_sockopt(buf), len, 0));
580   // THRIFT_GETTIMEOFDAY can change THRIFT_GET_SOCKET_ERROR
581   int errno_copy = THRIFT_GET_SOCKET_ERROR;
582 
583   // Check for error on read
584   if (got < 0) {
585     if (errno_copy == THRIFT_EAGAIN) {
586       // if no timeout we can assume that resource exhaustion has occurred.
587       if (recvTimeout_ == 0) {
588         throw TTransportException(TTransportException::TIMED_OUT,
589                                   "THRIFT_EAGAIN (unavailable resources)");
590       }
591       // check if this is the lack of resources or timeout case
592       struct timeval end;
593       THRIFT_GETTIMEOFDAY(&end, nullptr);
594       auto readElapsedMicros = static_cast<uint32_t>(((end.tv_sec - begin.tv_sec) * 1000 * 1000)
595                                                          + (end.tv_usec - begin.tv_usec));
596 
597       if (!eagainThresholdMicros || (readElapsedMicros < eagainThresholdMicros)) {
598         if (retries++ < maxRecvRetries_) {
599           THRIFT_SLEEP_USEC(50);
600           goto try_again;
601         } else {
602           throw TTransportException(TTransportException::TIMED_OUT,
603                                     "THRIFT_EAGAIN (unavailable resources)");
604         }
605       } else {
606         // infer that timeout has been hit
607         throw TTransportException(TTransportException::TIMED_OUT, "THRIFT_EAGAIN (timed out)");
608       }
609     }
610 
611     // If interrupted, try again
612     if (errno_copy == THRIFT_EINTR && retries++ < maxRecvRetries_) {
613       goto try_again;
614     }
615 
616     if (errno_copy == THRIFT_ECONNRESET) {
617       return 0;
618     }
619 
620     // This ish isn't open
621     if (errno_copy == THRIFT_ENOTCONN) {
622       throw TTransportException(TTransportException::NOT_OPEN, "THRIFT_ENOTCONN");
623     }
624 
625     // Timed out!
626     if (errno_copy == THRIFT_ETIMEDOUT) {
627       throw TTransportException(TTransportException::TIMED_OUT, "THRIFT_ETIMEDOUT");
628     }
629 
630     // Now it's not a try again case, but a real probblez
631     GlobalOutput.perror("TSocket::read() recv() " + getSocketInfo(), errno_copy);
632 
633     // Some other error, whatevz
634     throw TTransportException(TTransportException::UNKNOWN, "Unknown", errno_copy);
635   }
636 
637   return got;
638 }
639 
write(const uint8_t * buf,uint32_t len)640 void TSocket::write(const uint8_t* buf, uint32_t len) {
641   uint32_t sent = 0;
642 
643   while (sent < len) {
644     uint32_t b = write_partial(buf + sent, len - sent);
645     if (b == 0) {
646       // This should only happen if the timeout set with SO_SNDTIMEO expired.
647       // Raise an exception.
648       throw TTransportException(TTransportException::TIMED_OUT, "send timeout expired");
649     }
650     sent += b;
651   }
652 }
653 
write_partial(const uint8_t * buf,uint32_t len)654 uint32_t TSocket::write_partial(const uint8_t* buf, uint32_t len) {
655   if (socket_ == THRIFT_INVALID_SOCKET) {
656     throw TTransportException(TTransportException::NOT_OPEN, "Called write on non-open socket");
657   }
658 
659   uint32_t sent = 0;
660 
661   int flags = 0;
662 #ifdef MSG_NOSIGNAL
663   // Note the use of MSG_NOSIGNAL to suppress SIGPIPE errors, instead we
664   // check for the THRIFT_EPIPE return condition and close the socket in that case
665   flags |= MSG_NOSIGNAL;
666 #endif // ifdef MSG_NOSIGNAL
667 
668   int b = static_cast<int>(send(socket_, const_cast_sockopt(buf + sent), len - sent, flags));
669 
670   if (b < 0) {
671     if (THRIFT_GET_SOCKET_ERROR == THRIFT_EWOULDBLOCK || THRIFT_GET_SOCKET_ERROR == THRIFT_EAGAIN) {
672       return 0;
673     }
674     // Fail on a send error
675     int errno_copy = THRIFT_GET_SOCKET_ERROR;
676     GlobalOutput.perror("TSocket::write_partial() send() " + getSocketInfo(), errno_copy);
677 
678     if (errno_copy == THRIFT_EPIPE || errno_copy == THRIFT_ECONNRESET
679         || errno_copy == THRIFT_ENOTCONN) {
680       throw TTransportException(TTransportException::NOT_OPEN, "write() send()", errno_copy);
681     }
682 
683     throw TTransportException(TTransportException::UNKNOWN, "write() send()", errno_copy);
684   }
685 
686   // Fail on blocked send
687   if (b == 0) {
688     throw TTransportException(TTransportException::NOT_OPEN, "Socket send returned 0.");
689   }
690   return b;
691 }
692 
getHost() const693 std::string TSocket::getHost() const {
694   return host_;
695 }
696 
getPort() const697 int TSocket::getPort() const {
698   return port_;
699 }
700 
getPath() const701 std::string TSocket::getPath() const {
702     return path_;
703 }
704 
isUnixDomainSocket() const705 bool TSocket::isUnixDomainSocket() const {
706     return !path_.empty();
707 }
708 
setHost(string host)709 void TSocket::setHost(string host) {
710   host_ = host;
711 }
712 
setPort(int port)713 void TSocket::setPort(int port) {
714   port_ = port;
715 }
716 
setPath(std::string path)717 void TSocket::setPath(std::string path) {
718     path_ = path;
719 }
720 
setLinger(bool on,int linger)721 void TSocket::setLinger(bool on, int linger) {
722   lingerOn_ = on;
723   lingerVal_ = linger;
724   if (socket_ == THRIFT_INVALID_SOCKET) {
725     return;
726   }
727 
728 #ifndef _WIN32
729   struct linger l = {(lingerOn_ ? 1 : 0), lingerVal_};
730 #else
731   struct linger l = {static_cast<u_short>(lingerOn_ ? 1 : 0), static_cast<u_short>(lingerVal_)};
732 #endif
733 
734   int ret = setsockopt(socket_, SOL_SOCKET, SO_LINGER, cast_sockopt(&l), sizeof(l));
735   if (ret == -1) {
736     int errno_copy
737         = THRIFT_GET_SOCKET_ERROR; // Copy THRIFT_GET_SOCKET_ERROR because we're allocating memory.
738     GlobalOutput.perror("TSocket::setLinger() setsockopt() " + getSocketInfo(), errno_copy);
739   }
740 }
741 
setNoDelay(bool noDelay)742 void TSocket::setNoDelay(bool noDelay) {
743   noDelay_ = noDelay;
744   if (socket_ == THRIFT_INVALID_SOCKET || isUnixDomainSocket()) {
745     return;
746   }
747 
748   // Set socket to NODELAY
749   int v = noDelay_ ? 1 : 0;
750   int ret = setsockopt(socket_, IPPROTO_TCP, TCP_NODELAY, cast_sockopt(&v), sizeof(v));
751   if (ret == -1) {
752     int errno_copy
753         = THRIFT_GET_SOCKET_ERROR; // Copy THRIFT_GET_SOCKET_ERROR because we're allocating memory.
754     GlobalOutput.perror("TSocket::setNoDelay() setsockopt() " + getSocketInfo(), errno_copy);
755   }
756 }
757 
setConnTimeout(int ms)758 void TSocket::setConnTimeout(int ms) {
759   connTimeout_ = ms;
760 }
761 
setGenericTimeout(THRIFT_SOCKET s,int timeout_ms,int optname)762 void setGenericTimeout(THRIFT_SOCKET s, int timeout_ms, int optname) {
763   if (timeout_ms < 0) {
764     char errBuf[512];
765     sprintf(errBuf, "TSocket::setGenericTimeout with negative input: %d\n", timeout_ms);
766     GlobalOutput(errBuf);
767     return;
768   }
769 
770   if (s == THRIFT_INVALID_SOCKET) {
771     return;
772   }
773 
774 #ifdef _WIN32
775   DWORD platform_time = static_cast<DWORD>(timeout_ms);
776 #else
777   struct timeval platform_time = {(int)(timeout_ms / 1000), (int)((timeout_ms % 1000) * 1000)};
778 #endif
779 
780   int ret = setsockopt(s, SOL_SOCKET, optname, cast_sockopt(&platform_time), sizeof(platform_time));
781   if (ret == -1) {
782     int errno_copy
783         = THRIFT_GET_SOCKET_ERROR; // Copy THRIFT_GET_SOCKET_ERROR because we're allocating memory.
784     GlobalOutput.perror("TSocket::setGenericTimeout() setsockopt() ", errno_copy);
785   }
786 }
787 
setRecvTimeout(int ms)788 void TSocket::setRecvTimeout(int ms) {
789   setGenericTimeout(socket_, ms, SO_RCVTIMEO);
790   recvTimeout_ = ms;
791 }
792 
setSendTimeout(int ms)793 void TSocket::setSendTimeout(int ms) {
794   setGenericTimeout(socket_, ms, SO_SNDTIMEO);
795   sendTimeout_ = ms;
796 }
797 
setKeepAlive(bool keepAlive)798 void TSocket::setKeepAlive(bool keepAlive) {
799   keepAlive_ = keepAlive;
800 
801   if (socket_ == THRIFT_INVALID_SOCKET) {
802     return;
803   }
804 
805 #ifdef _WIN32
806   if (isUnixDomainSocket()) {
807       // Windows Domain sockets do not support SO_KEEPALIVE.
808       return;
809   }
810 #endif
811 
812   int value = keepAlive_;
813   int ret
814       = setsockopt(socket_, SOL_SOCKET, SO_KEEPALIVE, const_cast_sockopt(&value), sizeof(value));
815 
816   if (ret == -1) {
817     int errno_copy
818         = THRIFT_GET_SOCKET_ERROR; // Copy THRIFT_GET_SOCKET_ERROR because we're allocating memory.
819     GlobalOutput.perror("TSocket::setKeepAlive() setsockopt() " + getSocketInfo(), errno_copy);
820   }
821 }
822 
setMaxRecvRetries(int maxRecvRetries)823 void TSocket::setMaxRecvRetries(int maxRecvRetries) {
824   maxRecvRetries_ = maxRecvRetries;
825 }
826 
getSocketInfo() const827 string TSocket::getSocketInfo() const {
828   std::ostringstream oss;
829   if (!isUnixDomainSocket()) {
830     if (host_.empty() || port_ == 0) {
831       oss << "<Host: " << getPeerAddress();
832       oss << " Port: " << getPeerPort() << ">";
833     } else {
834       oss << "<Host: " << host_ << " Port: " << port_ << ">";
835     }
836   } else {
837     std::string fmt_path_ = path_;
838     // Handle printing abstract sockets (first character is a '\0' char):
839     if (!fmt_path_.empty() && fmt_path_[0] == '\0')
840       fmt_path_[0] = '@';
841     oss << "<Path: " << fmt_path_ << ">";
842   }
843   return oss.str();
844 }
845 
getPeerHost() const846 std::string TSocket::getPeerHost() const {
847   if (peerHost_.empty() && !isUnixDomainSocket()) {
848     struct sockaddr_storage addr;
849     struct sockaddr* addrPtr;
850     socklen_t addrLen;
851 
852     if (socket_ == THRIFT_INVALID_SOCKET) {
853       return host_;
854     }
855 
856     addrPtr = getCachedAddress(&addrLen);
857 
858     if (addrPtr == nullptr) {
859       addrLen = sizeof(addr);
860       if (getpeername(socket_, (sockaddr*)&addr, &addrLen) != 0) {
861         return peerHost_;
862       }
863       addrPtr = (sockaddr*)&addr;
864 
865       const_cast<TSocket&>(*this).setCachedAddress(addrPtr, addrLen);
866     }
867 
868     char clienthost[NI_MAXHOST];
869     char clientservice[NI_MAXSERV];
870 
871     getnameinfo((sockaddr*)addrPtr,
872                 addrLen,
873                 clienthost,
874                 sizeof(clienthost),
875                 clientservice,
876                 sizeof(clientservice),
877                 0);
878 
879     peerHost_ = clienthost;
880   }
881   return peerHost_;
882 }
883 
getPeerAddress() const884 std::string TSocket::getPeerAddress() const {
885   if (peerAddress_.empty() && !isUnixDomainSocket()) {
886     struct sockaddr_storage addr;
887     struct sockaddr* addrPtr;
888     socklen_t addrLen;
889 
890     if (socket_ == THRIFT_INVALID_SOCKET) {
891       return peerAddress_;
892     }
893 
894     addrPtr = getCachedAddress(&addrLen);
895 
896     if (addrPtr == nullptr) {
897       addrLen = sizeof(addr);
898       if (getpeername(socket_, (sockaddr*)&addr, &addrLen) != 0) {
899         return peerAddress_;
900       }
901       addrPtr = (sockaddr*)&addr;
902 
903       const_cast<TSocket&>(*this).setCachedAddress(addrPtr, addrLen);
904     }
905 
906     char clienthost[NI_MAXHOST];
907     char clientservice[NI_MAXSERV];
908 
909     getnameinfo(addrPtr,
910                 addrLen,
911                 clienthost,
912                 sizeof(clienthost),
913                 clientservice,
914                 sizeof(clientservice),
915                 NI_NUMERICHOST | NI_NUMERICSERV);
916 
917     peerAddress_ = clienthost;
918     peerPort_ = std::atoi(clientservice);
919   }
920   return peerAddress_;
921 }
922 
getPeerPort() const923 int TSocket::getPeerPort() const {
924   getPeerAddress();
925   return peerPort_;
926 }
927 
setCachedAddress(const sockaddr * addr,socklen_t len)928 void TSocket::setCachedAddress(const sockaddr* addr, socklen_t len) {
929   if (isUnixDomainSocket()) {
930     return;
931   }
932 
933   switch (addr->sa_family) {
934   case AF_INET:
935     if (len == sizeof(sockaddr_in)) {
936       memcpy((void*)&cachedPeerAddr_.ipv4, (void*)addr, len);
937     }
938     break;
939 
940   case AF_INET6:
941     if (len == sizeof(sockaddr_in6)) {
942       memcpy((void*)&cachedPeerAddr_.ipv6, (void*)addr, len);
943     }
944     break;
945   }
946   peerAddress_.clear();
947   peerHost_.clear();
948 }
949 
getCachedAddress(socklen_t * len) const950 sockaddr* TSocket::getCachedAddress(socklen_t* len) const {
951   switch (cachedPeerAddr_.ipv4.sin_family) {
952   case AF_INET:
953     *len = sizeof(sockaddr_in);
954     return (sockaddr*)&cachedPeerAddr_.ipv4;
955 
956   case AF_INET6:
957     *len = sizeof(sockaddr_in6);
958     return (sockaddr*)&cachedPeerAddr_.ipv6;
959 
960   default:
961     return nullptr;
962   }
963 }
964 
965 bool TSocket::useLowMinRto_ = false;
setUseLowMinRto(bool useLowMinRto)966 void TSocket::setUseLowMinRto(bool useLowMinRto) {
967   useLowMinRto_ = useLowMinRto;
968 }
getUseLowMinRto()969 bool TSocket::getUseLowMinRto() {
970   return useLowMinRto_;
971 }
972 
getOrigin() const973 const std::string TSocket::getOrigin() const {
974   std::ostringstream oss;
975   oss << getPeerHost() << ":" << getPeerPort();
976   return oss.str();
977 }
978 }
979 }
980 } // apache::thrift::transport
981