1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20package thrift
21
22import (
23	"crypto/tls"
24	"fmt"
25	"time"
26)
27
// Default values the TConfiguration getters fall back to when the
// corresponding field is left unset or holds an out-of-range value.
const (
	// Size limits.
	DEFAULT_MAX_MESSAGE_SIZE = 100 << 20 // 100 * 1024 * 1024
	DEFAULT_MAX_FRAME_SIZE   = 16384000

	// Strictness of TBinaryProtocol reads and writes.
	DEFAULT_TBINARY_STRICT_READ  = false
	DEFAULT_TBINARY_STRICT_WRITE = true

	// Timeouts for TSocket and TSSLSocket; 0 means no timeout.
	DEFAULT_CONNECT_TIMEOUT = 0
	DEFAULT_SOCKET_TIMEOUT  = 0
)
39
40// TConfiguration defines some configurations shared between TTransport,
41// TProtocol, TTransportFactory, TProtocolFactory, and other implementations.
42//
43// When constructing TConfiguration, you only need to specify the non-default
44// fields. All zero values have sane default values.
45//
46// Not all configurations defined are applicable to all implementations.
47// Implementations are free to ignore the configurations not applicable to them.
48//
49// All functions attached to this type are nil-safe.
50//
51// See [1] for spec.
52//
53// NOTE: When using TConfiguration, fill in all the configurations you want to
54// set across the stack, not only the ones you want to set in the immediate
55// TTransport/TProtocol.
56//
57// For example, say you want to migrate this old code into using TConfiguration:
58//
59//     sccket, err := thrift.NewTSocketTimeout("host:port", time.Second, time.Second)
60//     transFactory := thrift.NewTFramedTransportFactoryMaxLength(
61//         thrift.NewTTransportFactory(),
62//         1024 * 1024 * 256,
63//     )
64//     protoFactory := thrift.NewTBinaryProtocolFactory(true, true)
65//
66// This is the wrong way to do it because in the end the TConfiguration used by
67// socket and transFactory will be overwritten by the one used by protoFactory
68// because of TConfiguration propagation:
69//
70//     // bad example, DO NOT USE
71//     sccket := thrift.NewTSocketConf("host:port", &thrift.TConfiguration{
72//         ConnectTimeout: time.Second,
73//         SocketTimeout:  time.Second,
74//     })
75//     transFactory := thrift.NewTFramedTransportFactoryConf(
76//         thrift.NewTTransportFactory(),
77//         &thrift.TConfiguration{
78//             MaxFrameSize: 1024 * 1024 * 256,
79//         },
80//     )
81//     protoFactory := thrift.NewTBinaryProtocolFactoryConf(&thrift.TConfiguration{
82//         TBinaryStrictRead:  thrift.BoolPtr(true),
83//         TBinaryStrictWrite: thrift.BoolPtr(true),
84//     })
85//
86// This is the correct way to do it:
87//
88//     conf := &thrift.TConfiguration{
89//         ConnectTimeout: time.Second,
90//         SocketTimeout:  time.Second,
91//
92//         MaxFrameSize: 1024 * 1024 * 256,
93//
94//         TBinaryStrictRead:  thrift.BoolPtr(true),
95//         TBinaryStrictWrite: thrift.BoolPtr(true),
96//     }
97//     sccket := thrift.NewTSocketConf("host:port", conf)
98//     transFactory := thrift.NewTFramedTransportFactoryConf(thrift.NewTTransportFactory(), conf)
99//     protoFactory := thrift.NewTBinaryProtocolFactoryConf(conf)
100//
101// [1]: https://github.com/apache/thrift/blob/master/doc/specs/thrift-tconfiguration.md
102type TConfiguration struct {
103	// If <= 0, DEFAULT_MAX_MESSAGE_SIZE will be used instead.
104	MaxMessageSize int32
105
106	// If <= 0, DEFAULT_MAX_FRAME_SIZE will be used instead.
107	//
108	// Also if MaxMessageSize < MaxFrameSize,
109	// MaxMessageSize will be used instead.
110	MaxFrameSize int32
111
112	// Connect and socket timeouts to be used by TSocket and TSSLSocket.
113	//
114	// 0 means no timeout.
115	//
116	// If <0, DEFAULT_CONNECT_TIMEOUT and DEFAULT_SOCKET_TIMEOUT will be
117	// used.
118	ConnectTimeout time.Duration
119	SocketTimeout  time.Duration
120
121	// TLS config to be used by TSSLSocket.
122	TLSConfig *tls.Config
123
124	// Strict read/write configurations for TBinaryProtocol.
125	//
126	// BoolPtr helper function is available to use literal values.
127	TBinaryStrictRead  *bool
128	TBinaryStrictWrite *bool
129
130	// The wrapped protocol id to be used in THeader transport/protocol.
131	//
132	// THeaderProtocolIDPtr and THeaderProtocolIDPtrMust helper functions
133	// are provided to help filling this value.
134	THeaderProtocolID *THeaderProtocolID
135
136	// Used internally by deprecated constructors, to avoid overriding
137	// underlying TTransport/TProtocol's cfg by accidental propagations.
138	//
139	// For external users this is always false.
140	noPropagation bool
141}
142
143// GetMaxMessageSize returns the max message size an implementation should
144// follow.
145//
146// It's nil-safe. DEFAULT_MAX_MESSAGE_SIZE will be returned if tc is nil.
147func (tc *TConfiguration) GetMaxMessageSize() int32 {
148	if tc == nil || tc.MaxMessageSize <= 0 {
149		return DEFAULT_MAX_MESSAGE_SIZE
150	}
151	return tc.MaxMessageSize
152}
153
154// GetMaxFrameSize returns the max frame size an implementation should follow.
155//
156// It's nil-safe. DEFAULT_MAX_FRAME_SIZE will be returned if tc is nil.
157//
158// If the configured max message size is smaller than the configured max frame
159// size, the smaller one will be returned instead.
160func (tc *TConfiguration) GetMaxFrameSize() int32 {
161	if tc == nil {
162		return DEFAULT_MAX_FRAME_SIZE
163	}
164	maxFrameSize := tc.MaxFrameSize
165	if maxFrameSize <= 0 {
166		maxFrameSize = DEFAULT_MAX_FRAME_SIZE
167	}
168	if maxMessageSize := tc.GetMaxMessageSize(); maxMessageSize < maxFrameSize {
169		return maxMessageSize
170	}
171	return maxFrameSize
172}
173
174// GetConnectTimeout returns the connect timeout should be used by TSocket and
175// TSSLSocket.
176//
177// It's nil-safe. If tc is nil, DEFAULT_CONNECT_TIMEOUT will be returned instead.
178func (tc *TConfiguration) GetConnectTimeout() time.Duration {
179	if tc == nil || tc.ConnectTimeout < 0 {
180		return DEFAULT_CONNECT_TIMEOUT
181	}
182	return tc.ConnectTimeout
183}
184
185// GetSocketTimeout returns the socket timeout should be used by TSocket and
186// TSSLSocket.
187//
188// It's nil-safe. If tc is nil, DEFAULT_SOCKET_TIMEOUT will be returned instead.
189func (tc *TConfiguration) GetSocketTimeout() time.Duration {
190	if tc == nil || tc.SocketTimeout < 0 {
191		return DEFAULT_SOCKET_TIMEOUT
192	}
193	return tc.SocketTimeout
194}
195
196// GetTLSConfig returns the tls config should be used by TSSLSocket.
197//
198// It's nil-safe. If tc is nil, nil will be returned instead.
199func (tc *TConfiguration) GetTLSConfig() *tls.Config {
200	if tc == nil {
201		return nil
202	}
203	return tc.TLSConfig
204}
205
206// GetTBinaryStrictRead returns the strict read configuration TBinaryProtocol
207// should follow.
208//
209// It's nil-safe. DEFAULT_TBINARY_STRICT_READ will be returned if either tc or
210// tc.TBinaryStrictRead is nil.
211func (tc *TConfiguration) GetTBinaryStrictRead() bool {
212	if tc == nil || tc.TBinaryStrictRead == nil {
213		return DEFAULT_TBINARY_STRICT_READ
214	}
215	return *tc.TBinaryStrictRead
216}
217
218// GetTBinaryStrictWrite returns the strict read configuration TBinaryProtocol
219// should follow.
220//
221// It's nil-safe. DEFAULT_TBINARY_STRICT_WRITE will be returned if either tc or
222// tc.TBinaryStrictWrite is nil.
223func (tc *TConfiguration) GetTBinaryStrictWrite() bool {
224	if tc == nil || tc.TBinaryStrictWrite == nil {
225		return DEFAULT_TBINARY_STRICT_WRITE
226	}
227	return *tc.TBinaryStrictWrite
228}
229
230// GetTHeaderProtocolID returns the THeaderProtocolID should be used by
231// THeaderProtocol clients (for servers, they always use the same one as the
232// client instead).
233//
234// It's nil-safe. If either tc or tc.THeaderProtocolID is nil,
235// THeaderProtocolDefault will be returned instead.
236// THeaderProtocolDefault will also be returned if configured value is invalid.
237func (tc *TConfiguration) GetTHeaderProtocolID() THeaderProtocolID {
238	if tc == nil || tc.THeaderProtocolID == nil {
239		return THeaderProtocolDefault
240	}
241	protoID := *tc.THeaderProtocolID
242	if err := protoID.Validate(); err != nil {
243		return THeaderProtocolDefault
244	}
245	return protoID
246}
247
248// THeaderProtocolIDPtr validates and returns the pointer to id.
249//
250// If id is not a valid THeaderProtocolID, a pointer to THeaderProtocolDefault
251// and the validation error will be returned.
252func THeaderProtocolIDPtr(id THeaderProtocolID) (*THeaderProtocolID, error) {
253	err := id.Validate()
254	if err != nil {
255		id = THeaderProtocolDefault
256	}
257	return &id, err
258}
259
260// THeaderProtocolIDPtrMust validates and returns the pointer to id.
261//
262// It's similar to THeaderProtocolIDPtr, but it panics on validation errors
263// instead of returning them.
264func THeaderProtocolIDPtrMust(id THeaderProtocolID) *THeaderProtocolID {
265	ptr, err := THeaderProtocolIDPtr(id)
266	if err != nil {
267		panic(err)
268	}
269	return ptr
270}
271
272// TConfigurationSetter is an optional interface TProtocol, TTransport,
273// TProtocolFactory, TTransportFactory, and other implementations can implement.
274//
275// It's intended to be called during intializations.
276// The behavior of calling SetTConfiguration on a TTransport/TProtocol in the
277// middle of a message is undefined:
278// It may or may not change the behavior of the current processing message,
279// and it may even cause the current message to fail.
280//
281// Note for implementations: SetTConfiguration might be called multiple times
282// with the same value in quick successions due to the implementation of the
283// propagation. Implementations should make SetTConfiguration as simple as
284// possible (usually just overwrite the stored configuration and propagate it to
285// the wrapped TTransports/TProtocols).
286type TConfigurationSetter interface {
287	SetTConfiguration(*TConfiguration)
288}
289
290// PropagateTConfiguration propagates cfg to impl if impl implements
291// TConfigurationSetter and cfg is non-nil, otherwise it does nothing.
292//
293// NOTE: nil cfg is not propagated. If you want to propagate a TConfiguration
294// with everything being default value, use &TConfiguration{} explicitly instead.
295func PropagateTConfiguration(impl interface{}, cfg *TConfiguration) {
296	if cfg == nil || cfg.noPropagation {
297		return
298	}
299
300	if setter, ok := impl.(TConfigurationSetter); ok {
301		setter.SetTConfiguration(cfg)
302	}
303}
304
305func checkSizeForProtocol(size int32, cfg *TConfiguration) error {
306	if size < 0 {
307		return NewTProtocolExceptionWithType(
308			NEGATIVE_SIZE,
309			fmt.Errorf("negative size: %d", size),
310		)
311	}
312	if size > cfg.GetMaxMessageSize() {
313		return NewTProtocolExceptionWithType(
314			SIZE_LIMIT,
315			fmt.Errorf("size exceeded max allowed: %d", size),
316		)
317	}
318	return nil
319}
320
321type tTransportFactoryConf struct {
322	delegate TTransportFactory
323	cfg      *TConfiguration
324}
325
326func (f *tTransportFactoryConf) GetTransport(orig TTransport) (TTransport, error) {
327	trans, err := f.delegate.GetTransport(orig)
328	if err == nil {
329		PropagateTConfiguration(orig, f.cfg)
330		PropagateTConfiguration(trans, f.cfg)
331	}
332	return trans, err
333}
334
335func (f *tTransportFactoryConf) SetTConfiguration(cfg *TConfiguration) {
336	PropagateTConfiguration(f.delegate, f.cfg)
337	f.cfg = cfg
338}
339
340// TTransportFactoryConf wraps a TTransportFactory to propagate
341// TConfiguration on the factory's GetTransport calls.
342func TTransportFactoryConf(delegate TTransportFactory, conf *TConfiguration) TTransportFactory {
343	return &tTransportFactoryConf{
344		delegate: delegate,
345		cfg:      conf,
346	}
347}
348
349type tProtocolFactoryConf struct {
350	delegate TProtocolFactory
351	cfg      *TConfiguration
352}
353
354func (f *tProtocolFactoryConf) GetProtocol(trans TTransport) TProtocol {
355	proto := f.delegate.GetProtocol(trans)
356	PropagateTConfiguration(trans, f.cfg)
357	PropagateTConfiguration(proto, f.cfg)
358	return proto
359}
360
361func (f *tProtocolFactoryConf) SetTConfiguration(cfg *TConfiguration) {
362	PropagateTConfiguration(f.delegate, f.cfg)
363	f.cfg = cfg
364}
365
366// TProtocolFactoryConf wraps a TProtocolFactory to propagate
367// TConfiguration on the factory's GetProtocol calls.
368func TProtocolFactoryConf(delegate TProtocolFactory, conf *TConfiguration) TProtocolFactory {
369	return &tProtocolFactoryConf{
370		delegate: delegate,
371		cfg:      conf,
372	}
373}
374
375var (
376	_ TConfigurationSetter = (*tTransportFactoryConf)(nil)
377	_ TConfigurationSetter = (*tProtocolFactoryConf)(nil)
378)
379