1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20package thrift
21
22import (
23	"errors"
24	"net"
25	"sync/atomic"
26)
27
// socketConn is a wrapped net.Conn that tries to do connectivity check.
//
// All net.Conn methods not overridden below are promoted from the embedded
// Conn.
type socketConn struct {
	// Conn is the wrapped connection. It may be nil; isValid treats that as
	// "no connection".
	net.Conn

	// buffer is a one-byte scratch area. It is presumably used by the
	// platform-specific connectivity check (read0/checkConn) — TODO confirm.
	buffer [1]byte

	// closed is 0 while open and 1 once Close has run. Always access it via
	// sync/atomic.
	closed int32
}

// Compile-time proof that *socketConn satisfies net.Conn.
var _ net.Conn = (*socketConn)(nil)
37
38// createSocketConnFromReturn is a language sugar to help create socketConn from
39// return values of functions like net.Dial, tls.Dial, net.Listener.Accept, etc.
40func createSocketConnFromReturn(conn net.Conn, err error) (*socketConn, error) {
41	if err != nil {
42		return nil, err
43	}
44	return &socketConn{
45		Conn: conn,
46	}, nil
47}
48
49// wrapSocketConn wraps an existing net.Conn into *socketConn.
50func wrapSocketConn(conn net.Conn) *socketConn {
51	// In case conn is already wrapped,
52	// return it as-is and avoid double wrapping.
53	if sc, ok := conn.(*socketConn); ok {
54		return sc
55	}
56
57	return &socketConn{
58		Conn: conn,
59	}
60}
61
62// isValid checks whether there's a valid connection.
63//
64// It's nil safe, and returns false if sc itself is nil, or if the underlying
65// connection is nil.
66//
67// It's the same as the previous implementation of TSocket.IsOpen and
68// TSSLSocket.IsOpen before we added connectivity check.
69func (sc *socketConn) isValid() bool {
70	return sc != nil && sc.Conn != nil && atomic.LoadInt32(&sc.closed) == 0
71}
72
73// IsOpen checks whether the connection is open.
74//
75// It's nil safe, and returns false if sc itself is nil, or if the underlying
76// connection is nil.
77//
78// Otherwise, it tries to do a connectivity check and returns the result.
79//
80// It also has the side effect of resetting the previously set read deadline on
81// the socket. As a result, it shouldn't be called between setting read deadline
82// and doing actual read.
83func (sc *socketConn) IsOpen() bool {
84	if !sc.isValid() {
85		return false
86	}
87	if err := sc.checkConn(); err != nil {
88		if !errors.Is(err, net.ErrClosed) {
89			// The connectivity check failed and the error is not
90			// that the connection is already closed, we need to
91			// close the connection explicitly here to avoid
92			// connection leaks.
93			sc.Close()
94		}
95		return false
96	}
97	return true
98}
99
100// Read implements io.Reader.
101//
102// On Windows, it behaves the same as the underlying net.Conn.Read.
103//
104// On non-Windows, it treats len(p) == 0 as a connectivity check instead of
105// readability check, which means instead of blocking until there's something to
106// read (readability check), or always return (0, nil) (the default behavior of
107// go's stdlib implementation on non-Windows), it never blocks, and will return
108// an error if the connection is lost.
109func (sc *socketConn) Read(p []byte) (n int, err error) {
110	if len(p) == 0 {
111		return 0, sc.read0()
112	}
113
114	return sc.Conn.Read(p)
115}
116
117func (sc *socketConn) Close() error {
118	if !sc.isValid() {
119		// Already closed
120		return net.ErrClosed
121	}
122	atomic.StoreInt32(&sc.closed, 1)
123	return sc.Conn.Close()
124}
125