1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef _THRIFT_TRANSPORT_TTRANSPORTUTILS_H_
21 #define _THRIFT_TRANSPORT_TTRANSPORTUTILS_H_ 1
22 
23 #include <cstdlib>
24 #include <cstring>
25 #include <string>
26 #include <algorithm>
27 #include <thrift/transport/TTransport.h>
28 // Include the buffered transports that used to be defined here.
29 #include <thrift/transport/TBufferTransports.h>
30 #include <thrift/transport/TFileTransport.h>
31 
32 namespace apache {
33 namespace thrift {
34 namespace transport {
35 
36 /**
37  * The null transport is a dummy transport that doesn't actually do anything.
38  * It's sort of an analogy to /dev/null, you can never read anything from it
39  * and it will let you write anything you want to it, though it won't actually
40  * go anywhere.
41  *
42  */
43 class TNullTransport : public TVirtualTransport<TNullTransport> {
44 public:
45   TNullTransport() = default;
46 
47   ~TNullTransport() override = default;
48 
isOpen()49   bool isOpen() const override { return true; }
50 
open()51   void open() override {}
52 
write(const uint8_t *,uint32_t)53   void write(const uint8_t* /* buf */, uint32_t /* len */) { return; }
54 };
55 
/**
 * TPipedTransport. This transport allows piping of a request from one
 * transport to another, either when readEnd() or when writeEnd() is called.
 * The typical use case for this is to log a request or a reply to disk.
 * The underlying buffer expands to keep a copy of the entire
 * request/response.
 *
 */
64 class TPipedTransport : virtual public TTransport {
65 public:
66   TPipedTransport(std::shared_ptr<TTransport> srcTrans, std::shared_ptr<TTransport> dstTrans,
67                  std::shared_ptr<TConfiguration> config = nullptr)
TTransport(config)68     : TTransport(config),
69       srcTrans_(srcTrans),
70       dstTrans_(dstTrans),
71       rBufSize_(512),
72       rPos_(0),
73       rLen_(0),
74       wBufSize_(512),
75       wLen_(0) {
76 
77     // default is to to pipe the request when readEnd() is called
78     pipeOnRead_ = true;
79     pipeOnWrite_ = false;
80 
81     rBuf_ = (uint8_t*)std::malloc(sizeof(uint8_t) * rBufSize_);
82     if (rBuf_ == nullptr) {
83       throw std::bad_alloc();
84     }
85     wBuf_ = (uint8_t*)std::malloc(sizeof(uint8_t) * wBufSize_);
86     if (wBuf_ == nullptr) {
87       throw std::bad_alloc();
88     }
89   }
90 
91   TPipedTransport(std::shared_ptr<TTransport> srcTrans,
92                   std::shared_ptr<TTransport> dstTrans,
93                   uint32_t sz,
94                   std::shared_ptr<TConfiguration> config = nullptr)
TTransport(config)95     : TTransport(config),
96       srcTrans_(srcTrans),
97       dstTrans_(dstTrans),
98       rBufSize_(512),
99       rPos_(0),
100       rLen_(0),
101       wBufSize_(sz),
102       wLen_(0) {
103 
104     rBuf_ = (uint8_t*)std::malloc(sizeof(uint8_t) * rBufSize_);
105     if (rBuf_ == nullptr) {
106       throw std::bad_alloc();
107     }
108     wBuf_ = (uint8_t*)std::malloc(sizeof(uint8_t) * wBufSize_);
109     if (wBuf_ == nullptr) {
110       throw std::bad_alloc();
111     }
112   }
113 
~TPipedTransport()114   ~TPipedTransport() override {
115     std::free(rBuf_);
116     std::free(wBuf_);
117   }
118 
isOpen()119   bool isOpen() const override { return srcTrans_->isOpen(); }
120 
peek()121   bool peek() override {
122     if (rPos_ >= rLen_) {
123       // Double the size of the underlying buffer if it is full
124       if (rLen_ == rBufSize_) {
125         rBufSize_ *= 2;
126         auto * tmpBuf = (uint8_t*)std::realloc(rBuf_, sizeof(uint8_t) * rBufSize_);
127 	if (tmpBuf == nullptr) {
128 	  throw std::bad_alloc();
129 	}
130 	rBuf_ = tmpBuf;
131       }
132 
133       // try to fill up the buffer
134       rLen_ += srcTrans_->read(rBuf_ + rPos_, rBufSize_ - rPos_);
135     }
136     return (rLen_ > rPos_);
137   }
138 
open()139   void open() override { srcTrans_->open(); }
140 
close()141   void close() override { srcTrans_->close(); }
142 
setPipeOnRead(bool pipeVal)143   void setPipeOnRead(bool pipeVal) { pipeOnRead_ = pipeVal; }
144 
setPipeOnWrite(bool pipeVal)145   void setPipeOnWrite(bool pipeVal) { pipeOnWrite_ = pipeVal; }
146 
147   uint32_t read(uint8_t* buf, uint32_t len);
148 
readEnd()149   uint32_t readEnd() override {
150 
151     if (pipeOnRead_) {
152       dstTrans_->write(rBuf_, rPos_);
153       dstTrans_->flush();
154     }
155 
156     srcTrans_->readEnd();
157 
158     // If requests are being pipelined, copy down our read-ahead data,
159     // then reset our state.
160     int read_ahead = rLen_ - rPos_;
161     uint32_t bytes = rPos_;
162     memcpy(rBuf_, rBuf_ + rPos_, read_ahead);
163     rPos_ = 0;
164     rLen_ = read_ahead;
165 
166     return bytes;
167   }
168 
169   void write(const uint8_t* buf, uint32_t len);
170 
writeEnd()171   uint32_t writeEnd() override {
172     if (pipeOnWrite_) {
173       dstTrans_->write(wBuf_, wLen_);
174       dstTrans_->flush();
175     }
176     return wLen_;
177   }
178 
179   void flush() override;
180 
getTargetTransport()181   std::shared_ptr<TTransport> getTargetTransport() { return dstTrans_; }
182 
183   /*
184    * Override TTransport *_virt() functions to invoke our implementations.
185    * We cannot use TVirtualTransport to provide these, since we need to inherit
186    * virtually from TTransport.
187    */
read_virt(uint8_t * buf,uint32_t len)188   uint32_t read_virt(uint8_t* buf, uint32_t len) override { return this->read(buf, len); }
write_virt(const uint8_t * buf,uint32_t len)189   void write_virt(const uint8_t* buf, uint32_t len) override { this->write(buf, len); }
190 
191 protected:
192   std::shared_ptr<TTransport> srcTrans_;
193   std::shared_ptr<TTransport> dstTrans_;
194 
195   uint8_t* rBuf_;
196   uint32_t rBufSize_;
197   uint32_t rPos_;
198   uint32_t rLen_;
199 
200   uint8_t* wBuf_;
201   uint32_t wBufSize_;
202   uint32_t wLen_;
203 
204   bool pipeOnRead_;
205   bool pipeOnWrite_;
206 };
207 
208 /**
209  * Wraps a transport into a pipedTransport instance.
210  *
211  */
212 class TPipedTransportFactory : public TTransportFactory {
213 public:
214   TPipedTransportFactory() = default;
TPipedTransportFactory(std::shared_ptr<TTransport> dstTrans)215   TPipedTransportFactory(std::shared_ptr<TTransport> dstTrans) {
216     initializeTargetTransport(dstTrans);
217   }
218   ~TPipedTransportFactory() override = default;
219 
220   /**
221    * Wraps the base transport into a piped transport.
222    */
getTransport(std::shared_ptr<TTransport> srcTrans)223   std::shared_ptr<TTransport> getTransport(std::shared_ptr<TTransport> srcTrans) override {
224     return std::shared_ptr<TTransport>(new TPipedTransport(srcTrans, dstTrans_));
225   }
226 
initializeTargetTransport(std::shared_ptr<TTransport> dstTrans)227   virtual void initializeTargetTransport(std::shared_ptr<TTransport> dstTrans) {
228     if (dstTrans_.get() == nullptr) {
229       dstTrans_ = dstTrans;
230     } else {
231       throw TException("Target transport already initialized");
232     }
233   }
234 
235 protected:
236   std::shared_ptr<TTransport> dstTrans_;
237 };
238 
/**
 * TPipedFileReaderTransport. This is just like a TPipedTransport, except
 * that it also implements the TFileReaderTransport interface and keeps its
 * source as a TFileReaderTransport, so that clients who rely on that
 * specific transport type can still use its functions.
 *
 */
245 class TPipedFileReaderTransport : public TPipedTransport, public TFileReaderTransport {
246 public:
247   TPipedFileReaderTransport(std::shared_ptr<TFileReaderTransport> srcTrans,
248                             std::shared_ptr<TTransport> dstTrans,
249                             std::shared_ptr<TConfiguration> config = nullptr);
250 
251   ~TPipedFileReaderTransport() override;
252 
253   // TTransport functions
254   bool isOpen() const override;
255   bool peek() override;
256   void open() override;
257   void close() override;
258   uint32_t read(uint8_t* buf, uint32_t len);
259   uint32_t readAll(uint8_t* buf, uint32_t len);
260   uint32_t readEnd() override;
261   void write(const uint8_t* buf, uint32_t len);
262   uint32_t writeEnd() override;
263   void flush() override;
264 
265   // TFileReaderTransport functions
266   int32_t getReadTimeout() override;
267   void setReadTimeout(int32_t readTimeout) override;
268   uint32_t getNumChunks() override;
269   uint32_t getCurChunk() override;
270   void seekToChunk(int32_t chunk) override;
271   void seekToEnd() override;
272 
273   /*
274    * Override TTransport *_virt() functions to invoke our implementations.
275    * We cannot use TVirtualTransport to provide these, since we need to inherit
276    * virtually from TTransport.
277    */
read_virt(uint8_t * buf,uint32_t len)278   uint32_t read_virt(uint8_t* buf, uint32_t len) override { return this->read(buf, len); }
readAll_virt(uint8_t * buf,uint32_t len)279   uint32_t readAll_virt(uint8_t* buf, uint32_t len) override { return this->readAll(buf, len); }
write_virt(const uint8_t * buf,uint32_t len)280   void write_virt(const uint8_t* buf, uint32_t len) override { this->write(buf, len); }
281 
282 protected:
283   // shouldn't be used
284   TPipedFileReaderTransport();
285   std::shared_ptr<TFileReaderTransport> srcTrans_;
286 };
287 
/**
 * Creates a TPipedFileReaderTransport from a source TFileReaderTransport
 * and a destination transport
 *
 */
292 class TPipedFileReaderTransportFactory : public TPipedTransportFactory {
293 public:
294   TPipedFileReaderTransportFactory() = default;
TPipedFileReaderTransportFactory(std::shared_ptr<TTransport> dstTrans)295   TPipedFileReaderTransportFactory(std::shared_ptr<TTransport> dstTrans)
296     : TPipedTransportFactory(dstTrans) {}
297   ~TPipedFileReaderTransportFactory() override = default;
298 
getTransport(std::shared_ptr<TTransport> srcTrans)299   std::shared_ptr<TTransport> getTransport(std::shared_ptr<TTransport> srcTrans) override {
300     std::shared_ptr<TFileReaderTransport> pFileReaderTransport
301         = std::dynamic_pointer_cast<TFileReaderTransport>(srcTrans);
302     if (pFileReaderTransport.get() != nullptr) {
303       return getFileReaderTransport(pFileReaderTransport);
304     } else {
305       return std::shared_ptr<TTransport>();
306     }
307   }
308 
getFileReaderTransport(std::shared_ptr<TFileReaderTransport> srcTrans)309   std::shared_ptr<TFileReaderTransport> getFileReaderTransport(
310       std::shared_ptr<TFileReaderTransport> srcTrans) {
311     return std::shared_ptr<TFileReaderTransport>(
312         new TPipedFileReaderTransport(srcTrans, dstTrans_));
313   }
314 };
315 }
316 }
317 } // apache::thrift::transport
318 
319 #endif // #ifndef _THRIFT_TRANSPORT_TTRANSPORTUTILS_H_
320