1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef THRIFT_PY_COMPACT_H
21 #define THRIFT_PY_COMPACT_H
22 
23 #include <Python.h>
24 #include "ext/protocol.h"
25 #include "ext/endian.h"
26 #include <stdint.h>
27 #include <stack>
28 
29 namespace apache {
30 namespace thrift {
31 namespace py {
32 
33 class CompactProtocol : public ProtocolBase<CompactProtocol> {
34 public:
CompactProtocol()35   CompactProtocol() { readBool_.exists = false; }
36 
~CompactProtocol()37   virtual ~CompactProtocol() {}
38 
writeI8(int8_t val)39   void writeI8(int8_t val) { writeBuffer(reinterpret_cast<char*>(&val), 1); }
40 
writeI16(int16_t val)41   void writeI16(int16_t val) { writeVarint(toZigZag(val)); }
42 
writeI32(int32_t val)43   int writeI32(int32_t val) { return writeVarint(toZigZag(val)); }
44 
writeI64(int64_t val)45   void writeI64(int64_t val) { writeVarint64(toZigZag64(val)); }
46 
writeDouble(double dub)47   void writeDouble(double dub) {
48     union {
49       double f;
50       int64_t t;
51     } transfer;
52     transfer.f = htolell(dub);
53     writeBuffer(reinterpret_cast<char*>(&transfer.t), sizeof(int64_t));
54   }
55 
writeBool(int v)56   void writeBool(int v) { writeByte(static_cast<uint8_t>(v ? CT_BOOLEAN_TRUE : CT_BOOLEAN_FALSE)); }
57 
writeString(PyObject * value,int32_t len)58   void writeString(PyObject* value, int32_t len) {
59     writeVarint(len);
60     writeBuffer(PyBytes_AS_STRING(value), len);
61   }
62 
writeListBegin(PyObject * value,const SetListTypeArgs & args,int32_t len)63   bool writeListBegin(PyObject* value, const SetListTypeArgs& args, int32_t len) {
64     int ctype = toCompactType(args.element_type);
65     if (len <= 14) {
66       writeByte(static_cast<uint8_t>(len << 4 | ctype));
67     } else {
68       writeByte(0xf0 | ctype);
69       writeVarint(len);
70     }
71     return true;
72   }
73 
writeMapBegin(PyObject * value,const MapTypeArgs & args,int32_t len)74   bool writeMapBegin(PyObject* value, const MapTypeArgs& args, int32_t len) {
75     if (len == 0) {
76       writeByte(0);
77       return true;
78     }
79     int ctype = toCompactType(args.ktag) << 4 | toCompactType(args.vtag);
80     writeVarint(len);
81     writeByte(ctype);
82     return true;
83   }
84 
writeStructBegin()85   bool writeStructBegin() {
86     writeTags_.push(0);
87     return true;
88   }
writeStructEnd()89   bool writeStructEnd() {
90     writeTags_.pop();
91     return true;
92   }
93 
writeField(PyObject * value,const StructItemSpec & spec)94   bool writeField(PyObject* value, const StructItemSpec& spec) {
95     if (spec.type == T_BOOL) {
96       doWriteFieldBegin(spec, PyObject_IsTrue(value) ? CT_BOOLEAN_TRUE : CT_BOOLEAN_FALSE);
97       return true;
98     } else {
99       doWriteFieldBegin(spec, toCompactType(spec.type));
100       return encodeValue(value, spec.type, spec.typeargs);
101     }
102   }
103 
writeFieldStop()104   void writeFieldStop() { writeByte(0); }
105 
readBool(bool & val)106   bool readBool(bool& val) {
107     if (readBool_.exists) {
108       readBool_.exists = false;
109       val = readBool_.value;
110       return true;
111     }
112     char* buf;
113     if (!readBytes(&buf, 1)) {
114       return false;
115     }
116     val = buf[0] == CT_BOOLEAN_TRUE;
117     return true;
118   }
readI8(int8_t & val)119   bool readI8(int8_t& val) {
120     char* buf;
121     if (!readBytes(&buf, 1)) {
122       return false;
123     }
124     val = buf[0];
125     return true;
126   }
127 
readI16(int16_t & val)128   bool readI16(int16_t& val) {
129     uint16_t uval;
130     if (readVarint<uint16_t, 3>(uval)) {
131       val = fromZigZag<int16_t, uint16_t>(uval);
132       return true;
133     }
134     return false;
135   }
136 
readI32(int32_t & val)137   bool readI32(int32_t& val) {
138     uint32_t uval;
139     if (readVarint<uint32_t, 5>(uval)) {
140       val = fromZigZag<int32_t, uint32_t>(uval);
141       return true;
142     }
143     return false;
144   }
145 
readI64(int64_t & val)146   bool readI64(int64_t& val) {
147     uint64_t uval;
148     if (readVarint<uint64_t, 10>(uval)) {
149       val = fromZigZag<int64_t, uint64_t>(uval);
150       return true;
151     }
152     return false;
153   }
154 
readDouble(double & val)155   bool readDouble(double& val) {
156     union {
157       int64_t f;
158       double t;
159     } transfer;
160 
161     char* buf;
162     if (!readBytes(&buf, 8)) {
163       return false;
164     }
165     memcpy(&transfer.f, buf, sizeof(int64_t));
166     transfer.f = letohll(transfer.f);
167     val = transfer.t;
168     return true;
169   }
170 
readString(char ** buf)171   int32_t readString(char** buf) {
172     uint32_t len;
173     if (!readVarint<uint32_t, 5>(len) || !checkLengthLimit(len, stringLimit())) {
174       return -1;
175     }
176     if (len == 0) {
177       return 0;
178     }
179     if (!readBytes(buf, len)) {
180       return -1;
181     }
182     return len;
183   }
184 
readListBegin(TType & etype)185   int32_t readListBegin(TType& etype) {
186     uint8_t b;
187     if (!readByte(b)) {
188       return -1;
189     }
190     etype = getTType(b & 0xf);
191     if (etype == -1) {
192       return -1;
193     }
194     uint32_t len = (b >> 4) & 0xf;
195     if (len == 15 && !readVarint<uint32_t, 5>(len)) {
196       return -1;
197     }
198     if (!checkLengthLimit(len, containerLimit())) {
199       return -1;
200     }
201     return len;
202   }
203 
readMapBegin(TType & ktype,TType & vtype)204   int32_t readMapBegin(TType& ktype, TType& vtype) {
205     uint32_t len;
206     if (!readVarint<uint32_t, 5>(len) || !checkLengthLimit(len, containerLimit())) {
207       return -1;
208     }
209     if (len != 0) {
210       uint8_t kvType;
211       if (!readByte(kvType)) {
212         return -1;
213       }
214       ktype = getTType(kvType >> 4);
215       vtype = getTType(kvType & 0xf);
216       if (ktype == -1 || vtype == -1) {
217         return -1;
218       }
219     }
220     return len;
221   }
222 
readStructBegin()223   bool readStructBegin() {
224     readTags_.push(0);
225     return true;
226   }
readStructEnd()227   bool readStructEnd() {
228     readTags_.pop();
229     return true;
230   }
231   bool readFieldBegin(TType& type, int16_t& tag);
232 
skipBool()233   bool skipBool() {
234     bool val;
235     return readBool(val);
236   }
237 #define SKIPBYTES(n)                                                                               \
238   do {                                                                                             \
239     if (!readBytes(&dummy_buf_, (n))) {                                                            \
240       return false;                                                                                \
241     }                                                                                              \
242     return true;                                                                                   \
243   } while (0)
skipByte()244   bool skipByte() { SKIPBYTES(1); }
skipDouble()245   bool skipDouble() { SKIPBYTES(8); }
skipI16()246   bool skipI16() {
247     int16_t val;
248     return readI16(val);
249   }
skipI32()250   bool skipI32() {
251     int32_t val;
252     return readI32(val);
253   }
skipI64()254   bool skipI64() {
255     int64_t val;
256     return readI64(val);
257   }
skipString()258   bool skipString() {
259     uint32_t len;
260     if (!readVarint<uint32_t, 5>(len)) {
261       return false;
262     }
263     SKIPBYTES(len);
264   }
265 #undef SKIPBYTES
266 
267 private:
268   enum Types {
269     CT_STOP = 0x00,
270     CT_BOOLEAN_TRUE = 0x01,
271     CT_BOOLEAN_FALSE = 0x02,
272     CT_BYTE = 0x03,
273     CT_I16 = 0x04,
274     CT_I32 = 0x05,
275     CT_I64 = 0x06,
276     CT_DOUBLE = 0x07,
277     CT_BINARY = 0x08,
278     CT_LIST = 0x09,
279     CT_SET = 0x0A,
280     CT_MAP = 0x0B,
281     CT_STRUCT = 0x0C
282   };
283 
284   static const uint8_t TTypeToCType[];
285 
286   TType getTType(uint8_t type);
287 
toCompactType(TType type)288   int toCompactType(TType type) {
289     int i = static_cast<int>(type);
290     return i < 16 ? TTypeToCType[i] : -1;
291   }
292 
toZigZag(int32_t val)293   uint32_t toZigZag(int32_t val) { return (val >> 31) ^ (val << 1); }
294 
toZigZag64(int64_t val)295   uint64_t toZigZag64(int64_t val) { return (val >> 63) ^ (val << 1); }
296 
writeVarint(uint32_t val)297   int writeVarint(uint32_t val) {
298     int cnt = 1;
299     while (val & ~0x7fU) {
300       writeByte(static_cast<char>((val & 0x7fU) | 0x80U));
301       val >>= 7;
302       ++cnt;
303     }
304     writeByte(static_cast<char>(val));
305     return cnt;
306   }
307 
writeVarint64(uint64_t val)308   int writeVarint64(uint64_t val) {
309     int cnt = 1;
310     while (val & ~0x7fULL) {
311       writeByte(static_cast<char>((val & 0x7fULL) | 0x80ULL));
312       val >>= 7;
313       ++cnt;
314     }
315     writeByte(static_cast<char>(val));
316     return cnt;
317   }
318 
319   template <typename T, int Max>
readVarint(T & result)320   bool readVarint(T& result) {
321     uint8_t b;
322     T val = 0;
323     int shift = 0;
324     for (int i = 0; i < Max; ++i) {
325       if (!readByte(b)) {
326         return false;
327       }
328       if (b & 0x80) {
329         val |= static_cast<T>(b & 0x7f) << shift;
330       } else {
331         val |= static_cast<T>(b) << shift;
332         result = val;
333         return true;
334       }
335       shift += 7;
336     }
337     PyErr_Format(PyExc_OverflowError, "varint exceeded %d bytes", Max);
338     return false;
339   }
340 
341   template <typename S, typename U>
fromZigZag(U val)342   S fromZigZag(U val) {
343     return (val >> 1) ^ static_cast<U>(-static_cast<S>(val & 1));
344   }
345 
doWriteFieldBegin(const StructItemSpec & spec,int ctype)346   void doWriteFieldBegin(const StructItemSpec& spec, int ctype) {
347     int diff = spec.tag - writeTags_.top();
348     if (diff > 0 && diff <= 15) {
349       writeByte(static_cast<uint8_t>(diff << 4 | ctype));
350     } else {
351       writeByte(static_cast<uint8_t>(ctype));
352       writeI16(spec.tag);
353     }
354     writeTags_.top() = spec.tag;
355   }
356 
357   std::stack<int> writeTags_;
358   std::stack<int> readTags_;
359   struct {
360     bool exists;
361     bool value;
362   } readBool_;
363   char* dummy_buf_;
364 };
365 }
366 }
367 }
368 #endif // THRIFT_PY_COMPACT_H
369