1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20package thrift
21
22import (
23	"context"
24	"errors"
25)
26
// THeaderProtocol is a thrift protocol that implements THeader:
// https://github.com/apache/thrift/blob/master/doc/specs/HeaderFormat.md
//
// It supports either binary or compact protocol as the wrapped protocol.
//
// Most of the THeader handling happens inside THeaderTransport; this type
// mainly forwards each protocol call to the wrapped protocol.
type THeaderProtocol struct {
	// transport is the THeaderTransport that wraps the caller's transport.
	transport *THeaderTransport

	// protocol is the wrapped protocol that reads and writes are delegated
	// to. It is set when THeaderProtocol is constructed and is replaced with
	// a freshly created one, based on the transport's current protocol ID,
	// whenever WriteMessageBegin runs or ReadMessageBegin succeeds.
	protocol TProtocol

	// cfg is the configuration propagated to every wrapped protocol created.
	cfg *TConfiguration
}
41
42// Deprecated: Use NewTHeaderProtocolConf instead.
43func NewTHeaderProtocol(trans TTransport) *THeaderProtocol {
44	return newTHeaderProtocolConf(trans, &TConfiguration{
45		noPropagation: true,
46	})
47}
48
49// NewTHeaderProtocolConf creates a new THeaderProtocol from the underlying
50// transport with given TConfiguration.
51//
52// The passed in transport will be wrapped with THeaderTransport.
53//
54// Note that THeaderTransport handles frame and zlib by itself,
55// so the underlying transport should be a raw socket transports (TSocket or TSSLSocket),
56// instead of rich transports like TZlibTransport or TFramedTransport.
57func NewTHeaderProtocolConf(trans TTransport, conf *TConfiguration) *THeaderProtocol {
58	return newTHeaderProtocolConf(trans, conf)
59}
60
61func newTHeaderProtocolConf(trans TTransport, cfg *TConfiguration) *THeaderProtocol {
62	t := NewTHeaderTransportConf(trans, cfg)
63	p, _ := t.cfg.GetTHeaderProtocolID().GetProtocol(t)
64	PropagateTConfiguration(p, cfg)
65	return &THeaderProtocol{
66		transport: t,
67		protocol:  p,
68		cfg:       cfg,
69	}
70}
71
72type tHeaderProtocolFactory struct {
73	cfg *TConfiguration
74}
75
76func (f tHeaderProtocolFactory) GetProtocol(trans TTransport) TProtocol {
77	return newTHeaderProtocolConf(trans, f.cfg)
78}
79
80func (f *tHeaderProtocolFactory) SetTConfiguration(cfg *TConfiguration) {
81	f.cfg = cfg
82}
83
84// Deprecated: Use NewTHeaderProtocolFactoryConf instead.
85func NewTHeaderProtocolFactory() TProtocolFactory {
86	return NewTHeaderProtocolFactoryConf(&TConfiguration{
87		noPropagation: true,
88	})
89}
90
91// NewTHeaderProtocolFactoryConf creates a factory for THeader with given
92// TConfiguration.
93func NewTHeaderProtocolFactoryConf(conf *TConfiguration) TProtocolFactory {
94	return tHeaderProtocolFactory{
95		cfg: conf,
96	}
97}
98
99// Transport returns the underlying transport.
100//
101// It's guaranteed to be of type *THeaderTransport.
102func (p *THeaderProtocol) Transport() TTransport {
103	return p.transport
104}
105
106// GetReadHeaders returns the THeaderMap read from transport.
107func (p *THeaderProtocol) GetReadHeaders() THeaderMap {
108	return p.transport.GetReadHeaders()
109}
110
111// SetWriteHeader sets a header for write.
112func (p *THeaderProtocol) SetWriteHeader(key, value string) {
113	p.transport.SetWriteHeader(key, value)
114}
115
116// ClearWriteHeaders clears all write headers previously set.
117func (p *THeaderProtocol) ClearWriteHeaders() {
118	p.transport.ClearWriteHeaders()
119}
120
121// AddTransform add a transform for writing.
122func (p *THeaderProtocol) AddTransform(transform THeaderTransformID) error {
123	return p.transport.AddTransform(transform)
124}
125
126func (p *THeaderProtocol) Flush(ctx context.Context) error {
127	return p.transport.Flush(ctx)
128}
129
130func (p *THeaderProtocol) WriteMessageBegin(ctx context.Context, name string, typeID TMessageType, seqID int32) error {
131	newProto, err := p.transport.Protocol().GetProtocol(p.transport)
132	if err != nil {
133		return err
134	}
135	PropagateTConfiguration(newProto, p.cfg)
136	p.protocol = newProto
137	p.transport.SequenceID = seqID
138	return p.protocol.WriteMessageBegin(ctx, name, typeID, seqID)
139}
140
141func (p *THeaderProtocol) WriteMessageEnd(ctx context.Context) error {
142	if err := p.protocol.WriteMessageEnd(ctx); err != nil {
143		return err
144	}
145	return p.transport.Flush(ctx)
146}
147
148func (p *THeaderProtocol) WriteStructBegin(ctx context.Context, name string) error {
149	return p.protocol.WriteStructBegin(ctx, name)
150}
151
152func (p *THeaderProtocol) WriteStructEnd(ctx context.Context) error {
153	return p.protocol.WriteStructEnd(ctx)
154}
155
156func (p *THeaderProtocol) WriteFieldBegin(ctx context.Context, name string, typeID TType, id int16) error {
157	return p.protocol.WriteFieldBegin(ctx, name, typeID, id)
158}
159
160func (p *THeaderProtocol) WriteFieldEnd(ctx context.Context) error {
161	return p.protocol.WriteFieldEnd(ctx)
162}
163
164func (p *THeaderProtocol) WriteFieldStop(ctx context.Context) error {
165	return p.protocol.WriteFieldStop(ctx)
166}
167
168func (p *THeaderProtocol) WriteMapBegin(ctx context.Context, keyType TType, valueType TType, size int) error {
169	return p.protocol.WriteMapBegin(ctx, keyType, valueType, size)
170}
171
172func (p *THeaderProtocol) WriteMapEnd(ctx context.Context) error {
173	return p.protocol.WriteMapEnd(ctx)
174}
175
176func (p *THeaderProtocol) WriteListBegin(ctx context.Context, elemType TType, size int) error {
177	return p.protocol.WriteListBegin(ctx, elemType, size)
178}
179
180func (p *THeaderProtocol) WriteListEnd(ctx context.Context) error {
181	return p.protocol.WriteListEnd(ctx)
182}
183
184func (p *THeaderProtocol) WriteSetBegin(ctx context.Context, elemType TType, size int) error {
185	return p.protocol.WriteSetBegin(ctx, elemType, size)
186}
187
188func (p *THeaderProtocol) WriteSetEnd(ctx context.Context) error {
189	return p.protocol.WriteSetEnd(ctx)
190}
191
192func (p *THeaderProtocol) WriteBool(ctx context.Context, value bool) error {
193	return p.protocol.WriteBool(ctx, value)
194}
195
196func (p *THeaderProtocol) WriteByte(ctx context.Context, value int8) error {
197	return p.protocol.WriteByte(ctx, value)
198}
199
200func (p *THeaderProtocol) WriteI16(ctx context.Context, value int16) error {
201	return p.protocol.WriteI16(ctx, value)
202}
203
204func (p *THeaderProtocol) WriteI32(ctx context.Context, value int32) error {
205	return p.protocol.WriteI32(ctx, value)
206}
207
208func (p *THeaderProtocol) WriteI64(ctx context.Context, value int64) error {
209	return p.protocol.WriteI64(ctx, value)
210}
211
212func (p *THeaderProtocol) WriteDouble(ctx context.Context, value float64) error {
213	return p.protocol.WriteDouble(ctx, value)
214}
215
216func (p *THeaderProtocol) WriteString(ctx context.Context, value string) error {
217	return p.protocol.WriteString(ctx, value)
218}
219
220func (p *THeaderProtocol) WriteBinary(ctx context.Context, value []byte) error {
221	return p.protocol.WriteBinary(ctx, value)
222}
223
224func (p *THeaderProtocol) WriteUUID(ctx context.Context, value Tuuid) error {
225	return p.protocol.WriteUUID(ctx, value)
226}
227
228// ReadFrame calls underlying THeaderTransport's ReadFrame function.
229func (p *THeaderProtocol) ReadFrame(ctx context.Context) error {
230	return p.transport.ReadFrame(ctx)
231}
232
233func (p *THeaderProtocol) ReadMessageBegin(ctx context.Context) (name string, typeID TMessageType, seqID int32, err error) {
234	if err = p.transport.ReadFrame(ctx); err != nil {
235		return
236	}
237
238	var newProto TProtocol
239	newProto, err = p.transport.Protocol().GetProtocol(p.transport)
240	if err != nil {
241		var tAppExc TApplicationException
242		if !errors.As(err, &tAppExc) {
243			return
244		}
245		if e := p.protocol.WriteMessageBegin(ctx, "", EXCEPTION, seqID); e != nil {
246			return
247		}
248		if e := tAppExc.Write(ctx, p.protocol); e != nil {
249			return
250		}
251		if e := p.protocol.WriteMessageEnd(ctx); e != nil {
252			return
253		}
254		if e := p.transport.Flush(ctx); e != nil {
255			return
256		}
257		return
258	}
259	PropagateTConfiguration(newProto, p.cfg)
260	p.protocol = newProto
261
262	return p.protocol.ReadMessageBegin(ctx)
263}
264
265func (p *THeaderProtocol) ReadMessageEnd(ctx context.Context) error {
266	return p.protocol.ReadMessageEnd(ctx)
267}
268
269func (p *THeaderProtocol) ReadStructBegin(ctx context.Context) (name string, err error) {
270	return p.protocol.ReadStructBegin(ctx)
271}
272
273func (p *THeaderProtocol) ReadStructEnd(ctx context.Context) error {
274	return p.protocol.ReadStructEnd(ctx)
275}
276
277func (p *THeaderProtocol) ReadFieldBegin(ctx context.Context) (name string, typeID TType, id int16, err error) {
278	return p.protocol.ReadFieldBegin(ctx)
279}
280
281func (p *THeaderProtocol) ReadFieldEnd(ctx context.Context) error {
282	return p.protocol.ReadFieldEnd(ctx)
283}
284
285func (p *THeaderProtocol) ReadMapBegin(ctx context.Context) (keyType TType, valueType TType, size int, err error) {
286	return p.protocol.ReadMapBegin(ctx)
287}
288
289func (p *THeaderProtocol) ReadMapEnd(ctx context.Context) error {
290	return p.protocol.ReadMapEnd(ctx)
291}
292
293func (p *THeaderProtocol) ReadListBegin(ctx context.Context) (elemType TType, size int, err error) {
294	return p.protocol.ReadListBegin(ctx)
295}
296
297func (p *THeaderProtocol) ReadListEnd(ctx context.Context) error {
298	return p.protocol.ReadListEnd(ctx)
299}
300
301func (p *THeaderProtocol) ReadSetBegin(ctx context.Context) (elemType TType, size int, err error) {
302	return p.protocol.ReadSetBegin(ctx)
303}
304
305func (p *THeaderProtocol) ReadSetEnd(ctx context.Context) error {
306	return p.protocol.ReadSetEnd(ctx)
307}
308
309func (p *THeaderProtocol) ReadBool(ctx context.Context) (value bool, err error) {
310	return p.protocol.ReadBool(ctx)
311}
312
313func (p *THeaderProtocol) ReadByte(ctx context.Context) (value int8, err error) {
314	return p.protocol.ReadByte(ctx)
315}
316
317func (p *THeaderProtocol) ReadI16(ctx context.Context) (value int16, err error) {
318	return p.protocol.ReadI16(ctx)
319}
320
321func (p *THeaderProtocol) ReadI32(ctx context.Context) (value int32, err error) {
322	return p.protocol.ReadI32(ctx)
323}
324
325func (p *THeaderProtocol) ReadI64(ctx context.Context) (value int64, err error) {
326	return p.protocol.ReadI64(ctx)
327}
328
329func (p *THeaderProtocol) ReadDouble(ctx context.Context) (value float64, err error) {
330	return p.protocol.ReadDouble(ctx)
331}
332
333func (p *THeaderProtocol) ReadString(ctx context.Context) (value string, err error) {
334	return p.protocol.ReadString(ctx)
335}
336
337func (p *THeaderProtocol) ReadBinary(ctx context.Context) (value []byte, err error) {
338	return p.protocol.ReadBinary(ctx)
339}
340
341func (p *THeaderProtocol) ReadUUID(ctx context.Context) (value Tuuid, err error) {
342	return p.protocol.ReadUUID(ctx)
343}
344
345func (p *THeaderProtocol) Skip(ctx context.Context, fieldType TType) error {
346	return p.protocol.Skip(ctx, fieldType)
347}
348
349// SetTConfiguration implements TConfigurationSetter.
350func (p *THeaderProtocol) SetTConfiguration(cfg *TConfiguration) {
351	PropagateTConfiguration(p.transport, cfg)
352	PropagateTConfiguration(p.protocol, cfg)
353	p.cfg = cfg
354}
355
356var (
357	_ TConfigurationSetter = (*tHeaderProtocolFactory)(nil)
358	_ TConfigurationSetter = (*THeaderProtocol)(nil)
359)
360