1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20package thrift
21
22import (
23	"context"
24	"io"
25	"strings"
26	"testing"
27	"testing/iotest"
28)
29
30func TestFramedTransport(t *testing.T) {
31	trans := NewTFramedTransport(NewTMemoryBuffer())
32	TransportTest(t, trans, trans)
33}
34
35func TestTFramedTransportReuseTransport(t *testing.T) {
36	const (
37		content = "Hello, world!"
38		n       = 10
39	)
40	trans := NewTMemoryBuffer()
41	reader := NewTFramedTransport(trans)
42	writer := NewTFramedTransport(trans)
43
44	t.Run("pair", func(t *testing.T) {
45		for i := 0; i < n; i++ {
46			// write
47			if _, err := io.Copy(writer, strings.NewReader(content)); err != nil {
48				t.Fatalf("Failed to write on #%d: %v", i, err)
49			}
50			if err := writer.Flush(context.Background()); err != nil {
51				t.Fatalf("Failed to flush on #%d: %v", i, err)
52			}
53
54			// read
55			read, err := io.ReadAll(iotest.OneByteReader(reader))
56			if err != nil {
57				t.Errorf("Failed to read on #%d: %v", i, err)
58			}
59			if string(read) != content {
60				t.Errorf("Read #%d: want %q, got %q", i, content, read)
61			}
62		}
63	})
64
65	t.Run("batched", func(t *testing.T) {
66		// write
67		for i := 0; i < n; i++ {
68			if _, err := io.Copy(writer, strings.NewReader(content)); err != nil {
69				t.Fatalf("Failed to write on #%d: %v", i, err)
70			}
71			if err := writer.Flush(context.Background()); err != nil {
72				t.Fatalf("Failed to flush on #%d: %v", i, err)
73			}
74		}
75
76		// read
77		for i := 0; i < n; i++ {
78			const (
79				size = len(content)
80			)
81			var buf []byte
82			var err error
83			if i%2 == 0 {
84				// on even calls, use OneByteReader to make
85				// sure that small reads are fine
86				buf, err = io.ReadAll(io.LimitReader(iotest.OneByteReader(reader), int64(size)))
87			} else {
88				// on odd calls, make sure that we don't read
89				// more than written per frame
90				buf = make([]byte, size*2)
91				var n int
92				n, err = reader.Read(buf)
93				buf = buf[:n]
94			}
95			if err != nil {
96				t.Errorf("Failed to read on #%d: %v", i, err)
97			}
98			if string(buf) != content {
99				t.Errorf("Read #%d: want %q, got %q", i, content, buf)
100			}
101		}
102	})
103}
104