1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 #include <thrift/transport/THeaderTransport.h>
21 #include <thrift/TApplicationException.h>
22 #include <thrift/protocol/TProtocolTypes.h>
23 #include <thrift/protocol/TBinaryProtocol.h>
24 #include <thrift/protocol/TCompactProtocol.h>
25
26 #include <limits>
27 #include <utility>
28 #include <string>
29 #include <string.h>
30 #include <zlib.h>
31
32 using std::map;
33 using std::string;
34 using std::vector;
35
36 namespace apache {
37 namespace thrift {
38
39 using std::shared_ptr;
40
41 namespace transport {
42
43 using namespace apache::thrift::protocol;
44 using apache::thrift::protocol::TBinaryProtocol;
45
readSlow(uint8_t * buf,uint32_t len)46 uint32_t THeaderTransport::readSlow(uint8_t* buf, uint32_t len) {
47 if (clientType == THRIFT_UNFRAMED_BINARY || clientType == THRIFT_UNFRAMED_COMPACT) {
48 return transport_->read(buf, len);
49 }
50
51 return TFramedTransport::readSlow(buf, len);
52 }
53
getProtocolId() const54 uint16_t THeaderTransport::getProtocolId() const {
55 if (clientType == THRIFT_HEADER_CLIENT_TYPE) {
56 return protoId;
57 } else if (clientType == THRIFT_UNFRAMED_COMPACT || clientType == THRIFT_FRAMED_COMPACT) {
58 return T_COMPACT_PROTOCOL;
59 } else {
60 return T_BINARY_PROTOCOL; // Assume other transports use TBinary
61 }
62 }
63
ensureReadBuffer(uint32_t sz)64 void THeaderTransport::ensureReadBuffer(uint32_t sz) {
65 if (sz > rBufSize_) {
66 rBuf_.reset(new uint8_t[sz]);
67 rBufSize_ = sz;
68 }
69 }
70
readFrame()71 bool THeaderTransport::readFrame() {
72 // szN is network byte order of sz
73 uint32_t szN;
74 uint32_t sz;
75
76 // Read the size of the next frame.
77 // We can't use readAll(&sz, sizeof(sz)), since that always throws an
78 // exception on EOF. We want to throw an exception only if EOF occurs after
79 // partial size data.
80 uint32_t sizeBytesRead = 0;
81 while (sizeBytesRead < sizeof(szN)) {
82 uint8_t* szp = reinterpret_cast<uint8_t*>(&szN) + sizeBytesRead;
83 uint32_t bytesRead = transport_->read(szp, sizeof(szN) - sizeBytesRead);
84 if (bytesRead == 0) {
85 if (sizeBytesRead == 0) {
86 // EOF before any data was read.
87 return false;
88 } else {
89 // EOF after a partial frame header. Raise an exception.
90 throw TTransportException(TTransportException::END_OF_FILE,
91 "No more data to read after "
92 "partial frame header.");
93 }
94 }
95 sizeBytesRead += bytesRead;
96 }
97
98 sz = ntohl(szN);
99
100 ensureReadBuffer(4);
101
102 if ((sz & TBinaryProtocol::VERSION_MASK) == (uint32_t)TBinaryProtocol::VERSION_1) {
103 // unframed
104 clientType = THRIFT_UNFRAMED_BINARY;
105 memcpy(rBuf_.get(), &szN, sizeof(szN));
106 setReadBuffer(rBuf_.get(), 4);
107 } else if (static_cast<int8_t>(sz >> 24) == TCompactProtocol::PROTOCOL_ID
108 && (static_cast<int8_t>(sz >> 16) & TCompactProtocol::VERSION_MASK)
109 == TCompactProtocol::VERSION_N) {
110 clientType = THRIFT_UNFRAMED_COMPACT;
111 memcpy(rBuf_.get(), &szN, sizeof(szN));
112 setReadBuffer(rBuf_.get(), 4);
113 } else {
114 // Could be header format or framed. Check next uint32
115 uint32_t magic_n;
116 uint32_t magic;
117
118 if (sz > MAX_FRAME_SIZE) {
119 throw TTransportException(TTransportException::CORRUPTED_DATA,
120 "Header transport frame is too large");
121 }
122
123 ensureReadBuffer(sz);
124
125 // We can use readAll here, because it would be an invalid frame otherwise
126 transport_->readAll(reinterpret_cast<uint8_t*>(&magic_n), sizeof(magic_n));
127 memcpy(rBuf_.get(), &magic_n, sizeof(magic_n));
128 magic = ntohl(magic_n);
129
130 if ((magic & TBinaryProtocol::VERSION_MASK) == (uint32_t)TBinaryProtocol::VERSION_1) {
131 // framed
132 clientType = THRIFT_FRAMED_BINARY;
133 transport_->readAll(rBuf_.get() + 4, sz - 4);
134 setReadBuffer(rBuf_.get(), sz);
135 } else if (static_cast<int8_t>(magic >> 24) == TCompactProtocol::PROTOCOL_ID
136 && (static_cast<int8_t>(magic >> 16) & TCompactProtocol::VERSION_MASK)
137 == TCompactProtocol::VERSION_N) {
138 clientType = THRIFT_FRAMED_COMPACT;
139 transport_->readAll(rBuf_.get() + 4, sz - 4);
140 setReadBuffer(rBuf_.get(), sz);
141 } else if (HEADER_MAGIC == (magic & HEADER_MASK)) {
142 if (sz < 10) {
143 throw TTransportException(TTransportException::CORRUPTED_DATA,
144 "Header transport frame is too small");
145 }
146
147 transport_->readAll(rBuf_.get() + 4, sz - 4);
148
149 // header format
150 clientType = THRIFT_HEADER_CLIENT_TYPE;
151 // flags
152 flags = magic & FLAGS_MASK;
153 // seqId
154 uint32_t seqId_n;
155 memcpy(&seqId_n, rBuf_.get() + 4, sizeof(seqId_n));
156 seqId = ntohl(seqId_n);
157 // header size
158 uint16_t headerSize_n;
159 memcpy(&headerSize_n, rBuf_.get() + 8, sizeof(headerSize_n));
160 uint16_t headerSize = ntohs(headerSize_n);
161 setReadBuffer(rBuf_.get(), sz);
162 readHeaderFormat(headerSize, sz);
163 } else {
164 clientType = THRIFT_UNKNOWN_CLIENT_TYPE;
165 throw TTransportException(TTransportException::BAD_ARGS,
166 "Could not detect client transport type");
167 }
168 }
169
170 return true;
171 }
172
173 /**
174 * Reads a string from ptr, taking care not to reach headerBoundary
175 * Advances ptr on success
176 *
177 * @param str output string
178 * @throws CORRUPTED_DATA if size of string exceeds boundary
179 */
readString(uint8_t * & ptr,string & str,uint8_t const * headerBoundary)180 void THeaderTransport::readString(uint8_t*& ptr,
181 /* out */ string& str,
182 uint8_t const* headerBoundary) {
183 int32_t strLen;
184
185 uint32_t bytes = readVarint32(ptr, &strLen, headerBoundary);
186 if (strLen > headerBoundary - ptr) {
187 throw TTransportException(TTransportException::CORRUPTED_DATA,
188 "Info header length exceeds header size");
189 }
190 ptr += bytes;
191 str.assign(reinterpret_cast<const char*>(ptr), strLen);
192 ptr += strLen;
193 }
194
readHeaderFormat(uint16_t headerSize,uint32_t sz)195 void THeaderTransport::readHeaderFormat(uint16_t headerSize, uint32_t sz) {
196 readTrans_.clear(); // Clear out any previous transforms.
197 readHeaders_.clear(); // Clear out any previous headers.
198
199 // skip over already processed magic(4), seqId(4), headerSize(2)
200 auto* ptr = reinterpret_cast<uint8_t*>(rBuf_.get() + 10);
201
202 // Catch integer overflow, check for reasonable header size
203 if (headerSize >= 16384) {
204 throw TTransportException(TTransportException::CORRUPTED_DATA,
205 "Header size is unreasonable");
206 }
207 headerSize *= 4;
208 const uint8_t* const headerBoundary = ptr + headerSize;
209 if (headerSize > sz) {
210 throw TTransportException(TTransportException::CORRUPTED_DATA,
211 "Header size is larger than frame");
212 }
213 uint8_t* data = ptr + headerSize;
214 ptr += readVarint16(ptr, &protoId, headerBoundary);
215 int16_t numTransforms;
216 ptr += readVarint16(ptr, &numTransforms, headerBoundary);
217
218 // For now all transforms consist of only the ID, not data.
219 for (int i = 0; i < numTransforms; i++) {
220 int32_t transId;
221 ptr += readVarint32(ptr, &transId, headerBoundary);
222
223 readTrans_.push_back(transId);
224 }
225
226 // Info headers
227 while (ptr < headerBoundary) {
228 int32_t infoId;
229 ptr += readVarint32(ptr, &infoId, headerBoundary);
230
231 if (infoId == 0) {
232 // header padding
233 break;
234 }
235 if (infoId >= infoIdType::END) {
236 // cannot handle infoId
237 break;
238 }
239 switch (infoId) {
240 case infoIdType::KEYVALUE:
241 // Process key-value headers
242 uint32_t numKVHeaders;
243 ptr += readVarint32(ptr, (int32_t*)&numKVHeaders, headerBoundary);
244 // continue until we reach (padded) end of packet
245 while (numKVHeaders-- && ptr < headerBoundary) {
246 // format: key; value
247 // both: length (varint32); value (string)
248 string key, value;
249 readString(ptr, key, headerBoundary);
250 // value
251 readString(ptr, value, headerBoundary);
252 // save to headers
253 readHeaders_[key] = value;
254 }
255 break;
256 }
257 }
258
259 // Untransform the data section. rBuf will contain result.
260 untransform(data, safe_numeric_cast<uint32_t>(static_cast<ptrdiff_t>(sz) - (data - rBuf_.get())));
261 }
262
untransform(uint8_t * ptr,uint32_t sz)263 void THeaderTransport::untransform(uint8_t* ptr, uint32_t sz) {
264 // Update the transform buffer size if needed
265 resizeTransformBuffer();
266
267 for (vector<uint16_t>::const_iterator it = readTrans_.begin(); it != readTrans_.end(); ++it) {
268 const uint16_t transId = *it;
269
270 if (transId == ZLIB_TRANSFORM) {
271 z_stream stream;
272 int err;
273
274 stream.next_in = ptr;
275 stream.avail_in = sz;
276
277 // Setting these to 0 means use the default free/alloc functions
278 stream.zalloc = (alloc_func)nullptr;
279 stream.zfree = (free_func)nullptr;
280 stream.opaque = (voidpf)nullptr;
281 err = inflateInit(&stream);
282 if (err != Z_OK) {
283 throw TApplicationException(TApplicationException::MISSING_RESULT,
284 "Error while zlib deflateInit");
285 }
286 stream.next_out = tBuf_.get();
287 stream.avail_out = tBufSize_;
288 err = inflate(&stream, Z_FINISH);
289 if (err != Z_STREAM_END || stream.avail_out == 0) {
290 throw TApplicationException(TApplicationException::MISSING_RESULT,
291 "Error while zlib deflate");
292 }
293 sz = stream.total_out;
294
295 err = inflateEnd(&stream);
296 if (err != Z_OK) {
297 throw TApplicationException(TApplicationException::MISSING_RESULT,
298 "Error while zlib deflateEnd");
299 }
300
301 memcpy(ptr, tBuf_.get(), sz);
302 } else {
303 throw TApplicationException(TApplicationException::MISSING_RESULT, "Unknown transform");
304 }
305 }
306
307 setReadBuffer(ptr, sz);
308 }
309
310 /**
311 * We may have updated the wBuf size, update the tBuf size to match.
312 * Should be called in transform.
313 *
314 * The buffer should be slightly larger than write buffer size due to
315 * compression transforms (that may slightly grow on small frame sizes)
316 */
resizeTransformBuffer(uint32_t additionalSize)317 void THeaderTransport::resizeTransformBuffer(uint32_t additionalSize) {
318 if (tBufSize_ < wBufSize_ + DEFAULT_BUFFER_SIZE) {
319 uint32_t new_size = wBufSize_ + DEFAULT_BUFFER_SIZE + additionalSize;
320 auto* new_buf = new uint8_t[new_size];
321 tBuf_.reset(new_buf);
322 tBufSize_ = new_size;
323 }
324 }
325
transform(uint8_t * ptr,uint32_t sz)326 void THeaderTransport::transform(uint8_t* ptr, uint32_t sz) {
327 // Update the transform buffer size if needed
328 resizeTransformBuffer();
329
330 for (vector<uint16_t>::const_iterator it = writeTrans_.begin(); it != writeTrans_.end(); ++it) {
331 const uint16_t transId = *it;
332
333 if (transId == ZLIB_TRANSFORM) {
334 z_stream stream;
335 int err;
336
337 stream.next_in = ptr;
338 stream.avail_in = sz;
339
340 stream.zalloc = (alloc_func)nullptr;
341 stream.zfree = (free_func)nullptr;
342 stream.opaque = (voidpf)nullptr;
343 err = deflateInit(&stream, Z_DEFAULT_COMPRESSION);
344 if (err != Z_OK) {
345 throw TTransportException(TTransportException::CORRUPTED_DATA,
346 "Error while zlib deflateInit");
347 }
348 uint32_t tbuf_size = 0;
349 while (err == Z_OK) {
350 resizeTransformBuffer(tbuf_size);
351
352 stream.next_out = tBuf_.get();
353 stream.avail_out = tBufSize_;
354 err = deflate(&stream, Z_FINISH);
355 tbuf_size += DEFAULT_BUFFER_SIZE;
356 }
357 sz = stream.total_out;
358
359 err = deflateEnd(&stream);
360 if (err != Z_OK) {
361 throw TTransportException(TTransportException::CORRUPTED_DATA,
362 "Error while zlib deflateEnd");
363 }
364
365 memcpy(ptr, tBuf_.get(), sz);
366 } else {
367 throw TTransportException(TTransportException::CORRUPTED_DATA, "Unknown transform");
368 }
369 }
370
371 wBase_ = wBuf_.get() + sz;
372 }
373
resetProtocol()374 void THeaderTransport::resetProtocol() {
375 // Set to anything except HTTP type so we don't flush again
376 clientType = THRIFT_HEADER_CLIENT_TYPE;
377
378 // Read the header and decide which protocol to go with
379 readFrame();
380 }
381
getWriteBytes()382 uint32_t THeaderTransport::getWriteBytes() {
383 return safe_numeric_cast<uint32_t>(wBase_ - wBuf_.get());
384 }
385
386 /**
387 * Writes a string to a byte buffer, as size (varint32) + string (non-null
388 * terminated)
389 * Automatically advances ptr to after the written portion
390 */
writeString(uint8_t * & ptr,const string & str)391 void THeaderTransport::writeString(uint8_t*& ptr, const string& str) {
392 auto strLen = safe_numeric_cast<int32_t>(str.length());
393 ptr += writeVarint32(strLen, ptr);
394 memcpy(ptr, str.c_str(), strLen); // no need to write \0
395 ptr += strLen;
396 }
397
setHeader(const string & key,const string & value)398 void THeaderTransport::setHeader(const string& key, const string& value) {
399 writeHeaders_[key] = value;
400 }
401
getMaxWriteHeadersSize() const402 uint32_t THeaderTransport::getMaxWriteHeadersSize() const {
403 size_t maxWriteHeadersSize = 0;
404 THeaderTransport::StringToStringMap::const_iterator it;
405 for (it = writeHeaders_.begin(); it != writeHeaders_.end(); ++it) {
406 // add sizes of key and value to maxWriteHeadersSize
407 // 2 varints32 + the strings themselves
408 maxWriteHeadersSize += 5 + 5 + (it->first).length() + (it->second).length();
409 }
410 return safe_numeric_cast<uint32_t>(maxWriteHeadersSize);
411 }
412
clearHeaders()413 void THeaderTransport::clearHeaders() {
414 writeHeaders_.clear();
415 }
416
flush()417 void THeaderTransport::flush() {
418 resetConsumedMessageSize();
419 // Write out any data waiting in the write buffer.
420 uint32_t haveBytes = getWriteBytes();
421
422 if (clientType == THRIFT_HEADER_CLIENT_TYPE) {
423 transform(wBuf_.get(), haveBytes);
424 haveBytes = getWriteBytes(); // transform may have changed the size
425 }
426
427 // Note that we reset wBase_ prior to the underlying write
428 // to ensure we're in a sane state (i.e. internal buffer cleaned)
429 // if the underlying write throws up an exception
430 wBase_ = wBuf_.get();
431
432 if (haveBytes > MAX_FRAME_SIZE) {
433 throw TTransportException(TTransportException::CORRUPTED_DATA,
434 "Attempting to send frame that is too large");
435 }
436
437 if (clientType == THRIFT_HEADER_CLIENT_TYPE) {
438 // header size will need to be updated at the end because of varints.
439 // Make it big enough here for max varint size, plus 4 for padding.
440 uint32_t headerSize = (2 + getNumTransforms()) * THRIFT_MAX_VARINT32_BYTES + 4;
441 // add approximate size of info headers
442 headerSize += getMaxWriteHeadersSize();
443
444 // Pkt size
445 uint32_t maxSzHbo = headerSize + haveBytes // thrift header + payload
446 + 10; // common header section
447 uint8_t* pkt = tBuf_.get();
448 uint8_t* headerStart;
449 uint8_t* headerSizePtr;
450 uint8_t* pktStart = pkt;
451
452 if (maxSzHbo > tBufSize_) {
453 throw TTransportException(TTransportException::CORRUPTED_DATA,
454 "Attempting to header frame that is too large");
455 }
456
457 uint32_t szHbo;
458 uint32_t szNbo;
459 uint16_t headerSizeN;
460
461 // Fixup szHbo later
462 pkt += sizeof(szNbo);
463 uint16_t headerN = htons(HEADER_MAGIC >> 16);
464 memcpy(pkt, &headerN, sizeof(headerN));
465 pkt += sizeof(headerN);
466 uint16_t flagsN = htons(flags);
467 memcpy(pkt, &flagsN, sizeof(flagsN));
468 pkt += sizeof(flagsN);
469 uint32_t seqIdN = htonl(seqId);
470 memcpy(pkt, &seqIdN, sizeof(seqIdN));
471 pkt += sizeof(seqIdN);
472 headerSizePtr = pkt;
473 // Fixup headerSizeN later
474 pkt += sizeof(headerSizeN);
475 headerStart = pkt;
476
477 pkt += writeVarint32(protoId, pkt);
478 pkt += writeVarint32(getNumTransforms(), pkt);
479
480 // For now, each transform is only the ID, no following data.
481 for (vector<uint16_t>::const_iterator it = writeTrans_.begin(); it != writeTrans_.end(); ++it) {
482 pkt += writeVarint32(*it, pkt);
483 }
484
485 // write info headers
486
487 // for now only write kv-headers
488 auto headerCount = safe_numeric_cast<int32_t>(writeHeaders_.size());
489 if (headerCount > 0) {
490 pkt += writeVarint32(infoIdType::KEYVALUE, pkt);
491 // Write key-value headers count
492 pkt += writeVarint32(static_cast<int32_t>(headerCount), pkt);
493 // Write info headers
494 map<string, string>::const_iterator it;
495 for (it = writeHeaders_.begin(); it != writeHeaders_.end(); ++it) {
496 writeString(pkt, it->first); // key
497 writeString(pkt, it->second); // value
498 }
499 writeHeaders_.clear();
500 }
501
502 // Fixups after varint size calculations
503 headerSize = safe_numeric_cast<uint32_t>(pkt - headerStart);
504 uint8_t padding = 4 - (headerSize % 4);
505 headerSize += padding;
506
507 // Pad out pkt with 0x00
508 for (int i = 0; i < padding; i++) {
509 *(pkt++) = 0x00;
510 }
511
512 // Pkt size
513 ptrdiff_t szHbp = (headerStart - pktStart - 4);
514 if (static_cast<uint64_t>(szHbp) > static_cast<uint64_t>((std::numeric_limits<uint32_t>().max)()) - (headerSize + haveBytes)) {
515 throw TTransportException(TTransportException::CORRUPTED_DATA,
516 "Header section size is unreasonable");
517 }
518 szHbo = headerSize + haveBytes // thrift header + payload
519 + static_cast<uint32_t>(szHbp); // common header section
520 headerSizeN = htons(headerSize / 4);
521 memcpy(headerSizePtr, &headerSizeN, sizeof(headerSizeN));
522
523 // Set framing size.
524 szNbo = htonl(szHbo);
525 memcpy(pktStart, &szNbo, sizeof(szNbo));
526
527 outTransport_->write(pktStart, szHbo - haveBytes + 4);
528 outTransport_->write(wBuf_.get(), haveBytes);
529 } else if (clientType == THRIFT_FRAMED_BINARY || clientType == THRIFT_FRAMED_COMPACT) {
530 auto szHbo = (uint32_t)haveBytes;
531 uint32_t szNbo = htonl(szHbo);
532
533 outTransport_->write(reinterpret_cast<uint8_t*>(&szNbo), 4);
534 outTransport_->write(wBuf_.get(), haveBytes);
535 } else if (clientType == THRIFT_UNFRAMED_BINARY || clientType == THRIFT_UNFRAMED_COMPACT) {
536 outTransport_->write(wBuf_.get(), haveBytes);
537 } else {
538 throw TTransportException(TTransportException::BAD_ARGS, "Unknown client type");
539 }
540
541 // Flush the underlying transport.
542 outTransport_->flush();
543 }
544
545 /**
546 * Read an i16 from the wire as a varint. The MSB of each byte is set
547 * if there is another byte to follow. This can read up to 3 bytes.
548 */
readVarint16(uint8_t const * ptr,int16_t * i16,uint8_t const * boundary)549 uint32_t THeaderTransport::readVarint16(uint8_t const* ptr, int16_t* i16, uint8_t const* boundary) {
550 int32_t val;
551 uint32_t rsize = readVarint32(ptr, &val, boundary);
552 *i16 = (int16_t)val;
553 return rsize;
554 }
555
556 /**
557 * Read an i32 from the wire as a varint. The MSB of each byte is set
558 * if there is another byte to follow. This can read up to 5 bytes.
559 */
readVarint32(uint8_t const * ptr,int32_t * i32,uint8_t const * boundary)560 uint32_t THeaderTransport::readVarint32(uint8_t const* ptr, int32_t* i32, uint8_t const* boundary) {
561
562 uint32_t rsize = 0;
563 uint32_t val = 0;
564 int shift = 0;
565
566 while (true) {
567 if (ptr == boundary) {
568 throw TApplicationException(TApplicationException::INVALID_MESSAGE_TYPE,
569 "Trying to read past header boundary");
570 }
571 uint8_t byte = *(ptr++);
572 rsize++;
573 val |= (uint64_t)(byte & 0x7f) << shift;
574 shift += 7;
575 if (!(byte & 0x80)) {
576 *i32 = val;
577 return rsize;
578 }
579 }
580 }
581
582 /**
583 * Write an i32 as a varint. Results in 1-5 bytes on the wire.
584 */
writeVarint32(int32_t n,uint8_t * pkt)585 uint32_t THeaderTransport::writeVarint32(int32_t n, uint8_t* pkt) {
586 uint8_t buf[5];
587 uint32_t wsize = 0;
588
589 while (true) {
590 if ((n & ~0x7F) == 0) {
591 buf[wsize++] = (int8_t)n;
592 break;
593 } else {
594 buf[wsize++] = (int8_t)((n & 0x7F) | 0x80);
595 n >>= 7;
596 }
597 }
598
599 // Caller will advance pkt.
600 for (uint32_t i = 0; i < wsize; i++) {
601 pkt[i] = buf[i];
602 }
603
604 return wsize;
605 }
606
writeVarint16(int16_t n,uint8_t * pkt)607 uint32_t THeaderTransport::writeVarint16(int16_t n, uint8_t* pkt) {
608 return writeVarint32(n, pkt);
609 }
610 }
611 }
612 } // apache::thrift::transport
613