1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10#   http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
20module Thrift
21  class BinaryProtocol < BaseProtocol
22    VERSION_MASK = 0xffff0000
23    VERSION_1 = 0x80010000
24    TYPE_MASK = 0x000000ff
25
26    attr_reader :strict_read, :strict_write
27
28    def initialize(trans, strict_read=true, strict_write=true)
29      super(trans)
30      @strict_read = strict_read
31      @strict_write = strict_write
32
33      # Pre-allocated read buffer for fixed-size read methods. Needs to be at least 8 bytes long for
34      # read_i64() and read_double().
35      @rbuf = Bytes.empty_byte_buffer(8)
36    end
37
38    def write_message_begin(name, type, seqid)
39      # this is necessary because we added (needed) bounds checking to
40      # write_i32, and 0x80010000 is too big for that.
41      if strict_write
42        write_i16(VERSION_1 >> 16)
43        write_i16(type)
44        write_string(name)
45        write_i32(seqid)
46      else
47        write_string(name)
48        write_byte(type)
49        write_i32(seqid)
50      end
51    end
52
53    def write_struct_begin(name); nil; end
54
55    def write_field_begin(name, type, id)
56      write_byte(type)
57      write_i16(id)
58    end
59
60    def write_field_stop
61      write_byte(Thrift::Types::STOP)
62    end
63
64    def write_map_begin(ktype, vtype, size)
65      write_byte(ktype)
66      write_byte(vtype)
67      write_i32(size)
68    end
69
70    def write_list_begin(etype, size)
71      write_byte(etype)
72      write_i32(size)
73    end
74
75    def write_set_begin(etype, size)
76      write_byte(etype)
77      write_i32(size)
78    end
79
80    def write_bool(bool)
81      write_byte(bool ? 1 : 0)
82    end
83
84    def write_byte(byte)
85      raise RangeError if byte < -2**31 || byte >= 2**32
86      trans.write([byte].pack('c'))
87    end
88
89    def write_i16(i16)
90      trans.write([i16].pack('n'))
91    end
92
93    def write_i32(i32)
94      raise RangeError if i32 < -2**31 || i32 >= 2**31
95      trans.write([i32].pack('N'))
96    end
97
98    def write_i64(i64)
99      raise RangeError if i64 < -2**63 || i64 >= 2**64
100      hi = i64 >> 32
101      lo = i64 & 0xffffffff
102      trans.write([hi, lo].pack('N2'))
103    end
104
105    def write_double(dub)
106      trans.write([dub].pack('G'))
107    end
108
109    def write_string(str)
110      buf = Bytes.convert_to_utf8_byte_buffer(str)
111      write_binary(buf)
112    end
113
114    def write_binary(buf)
115      write_i32(buf.bytesize)
116      trans.write(buf)
117    end
118
119    def read_message_begin
120      version = read_i32
121      if version < 0
122        if (version & VERSION_MASK != VERSION_1)
123          raise ProtocolException.new(ProtocolException::BAD_VERSION, 'Missing version identifier')
124        end
125        type = version & TYPE_MASK
126        name = read_string
127        seqid = read_i32
128        [name, type, seqid]
129      else
130        if strict_read
131          raise ProtocolException.new(ProtocolException::BAD_VERSION, 'No version identifier, old protocol client?')
132        end
133        name = trans.read_all(version)
134        type = read_byte
135        seqid = read_i32
136        [name, type, seqid]
137      end
138    end
139
140    def read_struct_begin; nil; end
141
142    def read_field_begin
143      type = read_byte
144      if (type == Types::STOP)
145        [nil, type, 0]
146      else
147        id = read_i16
148        [nil, type, id]
149      end
150    end
151
152    def read_map_begin
153      ktype = read_byte
154      vtype = read_byte
155      size = read_i32
156      [ktype, vtype, size]
157    end
158
159    def read_list_begin
160      etype = read_byte
161      size = read_i32
162      [etype, size]
163    end
164
165    def read_set_begin
166      etype = read_byte
167      size = read_i32
168      [etype, size]
169    end
170
171    def read_bool
172      byte = read_byte
173      byte != 0
174    end
175
176    def read_byte
177      val = trans.read_byte
178      if (val > 0x7f)
179        val = 0 - ((val - 1) ^ 0xff)
180      end
181      val
182    end
183
184    def read_i16
185      trans.read_into_buffer(@rbuf, 2)
186      val, = @rbuf.unpack('n')
187      if (val > 0x7fff)
188        val = 0 - ((val - 1) ^ 0xffff)
189      end
190      val
191    end
192
193    def read_i32
194      trans.read_into_buffer(@rbuf, 4)
195      val, = @rbuf.unpack('N')
196      if (val > 0x7fffffff)
197        val = 0 - ((val - 1) ^ 0xffffffff)
198      end
199      val
200    end
201
202    def read_i64
203      trans.read_into_buffer(@rbuf, 8)
204      hi, lo = @rbuf.unpack('N2')
205      if (hi > 0x7fffffff)
206        hi ^= 0xffffffff
207        lo ^= 0xffffffff
208        0 - (hi << 32) - lo - 1
209      else
210        (hi << 32) + lo
211      end
212    end
213
214    def read_double
215      trans.read_into_buffer(@rbuf, 8)
216      val = @rbuf.unpack('G').first
217      val
218    end
219
220    def read_string
221      buffer = read_binary
222      Bytes.convert_to_string(buffer)
223    end
224
225    def read_binary
226      size = read_i32
227      trans.read_all(size)
228    end
229
230    def to_s
231      "binary(#{super.to_s})"
232    end
233  end
234
235  class BinaryProtocolFactory < BaseProtocolFactory
236    def get_protocol(trans)
237      return Thrift::BinaryProtocol.new(trans)
238    end
239
240    def to_s
241      "binary"
242    end
243  end
244end
245