#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

from struct import pack, unpack
from thrift.Thrift import TException
from ..compat import BufferIO


class TTransportException(TException):
    """Exception raised by the transport layer.

    Carries a numeric ``type`` (one of the class-level codes below) and an
    optional ``inner`` object alongside the message handed to TException.
    """

    # Error codes accepted as the ``type`` argument.
    UNKNOWN = 0
    NOT_OPEN = 1
    ALREADY_OPEN = 2
    TIMED_OUT = 3
    END_OF_FILE = 4
    NEGATIVE_SIZE = 5
    SIZE_LIMIT = 6
    INVALID_CLIENT_TYPE = 7

    def __init__(self, type=UNKNOWN, message=None, inner=None):
        """Store the error code and inner object; pass message to TException.

        type -- one of the codes above (defaults to UNKNOWN)
        message -- forwarded unchanged to TException.__init__
        inner -- kept on self.inner (presumably the underlying exception;
                 confirm against callers)
        """
        TException.__init__(self, message)
        self.type = type
        self.inner = inner


class TTransportBase(object):
    """Base class for Thrift transport layer.

    Every method except readAll is a no-op placeholder here; concrete
    transports are expected to override them.
    """

    def isOpen(self):
        pass

    def open(self):
        pass

    def close(self):
        pass

    def read(self, sz):
        pass

    def readAll(self, sz):
        """Call read() repeatedly until exactly sz bytes were collected.

        Returns the concatenated bytes. Raises EOFError as soon as read()
        hands back an empty chunk while bytes are still outstanding.
        """
        collected = b''
        pending = sz
        while pending > 0:
            piece = self.read(pending)
            piece_len = len(piece)
            pending -= piece_len
            collected += piece

            # An empty chunk means the peer has nothing more to give.
            if piece_len == 0:
                raise EOFError()

        return collected

    def write(self, buf):
        pass

    def flush(self):
        pass


# This class should be thought of as an interface.
class CReadableTransport(object):
    """Base class for transports that are readable from C."""

    # TODO(dreiss): Think about changing this interface to allow us to use
    #               a (Python, not c) StringIO instead, because it allows
    #               you to write after reading.

    # NOTE: cstringio_buf is a read-only property (no setter is defined);
    #       implementations replace the buffer through their own attribute.
    @property
    def cstringio_buf(self):
        """A cStringIO buffer that contains the current chunk we are reading."""
        pass

    def cstringio_refill(self, partialread, reqlen):
        """Refills cstringio_buf.

        Returns the currently used buffer (which can but need not be the same
        as the old cstringio_buf). partialread is what the C code has read
        from the buffer, and should be inserted into the buffer before any
        more reads. The return value must be a new, not borrowed reference.
        Something along the lines of self._buf should be fine.

        If reqlen bytes can't be read, throw EOFError.
        """
        pass


class TServerTransportBase(object):
    """Base class for Thrift server transports.

    All methods are no-op placeholders to be overridden.
    """

    def listen(self):
        pass

    def accept(self):
        pass

    def close(self):
        pass


class TTransportFactoryBase(object):
    """Base class for a Transport Factory."""

    def getTransport(self, trans):
        """Return trans unchanged (no wrapping)."""
        return trans


class TBufferedTransportFactory(object):
    """Factory transport that builds buffered transports."""

    def getTransport(self, trans):
        """Wrap trans in a TBufferedTransport with the default buffer size."""
        buffered = TBufferedTransport(trans)
        return buffered


class TBufferedTransport(TTransportBase, CReadableTransport):
    """Class that wraps another transport and buffers its I/O.

    The implementation uses a (configurable) fixed-size read buffer
    but buffers all writes until a flush is performed.
    """
    DEFAULT_BUFFER = 4096

    def __init__(self, trans, rbuf_size=DEFAULT_BUFFER):
        """trans -- the transport to wrap
        rbuf_size -- minimum number of bytes requested from trans per refill
        """
        self.__trans = trans
        self.__wbuf = BufferIO()
        # Pass string argument to initialize read buffer as cStringIO.InputType
        self.__rbuf = BufferIO(b'')
        self.__rbuf_size = rbuf_size

    def isOpen(self):
        return self.__trans.isOpen()

    def open(self):
        return self.__trans.open()

    def close(self):
        return self.__trans.close()

    def read(self, sz):
        """Return up to sz bytes, refilling the read buffer when it is empty.

        May return fewer than sz bytes; an empty result means the wrapped
        transport returned nothing on refill.
        """
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret
        # Buffer drained: fetch at least a full buffer's worth in one call.
        self.__rbuf = BufferIO(self.__trans.read(max(sz, self.__rbuf_size)))
        return self.__rbuf.read(sz)

    def write(self, buf):
        """Append buf to the pending write buffer (sent on flush)."""
        try:
            self.__wbuf.write(buf)
        except Exception:
            # on exception reset wbuf so it doesn't contain a partial function
            # call; bare raise keeps the original traceback intact
            self.__wbuf = BufferIO()
            raise

    def flush(self):
        """Send everything buffered by write() and flush the wrapped transport."""
        out = self.__wbuf.getvalue()
        # reset wbuf before write/flush to preserve state on underlying failure
        self.__wbuf = BufferIO()
        self.__trans.write(out)
        self.__trans.flush()

    # Implement the CReadableTransport interface.
    @property
    def cstringio_buf(self):
        return self.__rbuf

    def cstringio_refill(self, partialread, reqlen):
        """Rebuild the read buffer as partialread plus freshly read data.

        Guarantees the new buffer holds at least reqlen bytes; readAll raises
        EOFError if the wrapped transport cannot supply them.
        """
        retstring = partialread
        if reqlen < self.__rbuf_size:
            # try to make a read of as much as we can.
            retstring += self.__trans.read(self.__rbuf_size)

        # but make sure we do read reqlen bytes.
        if len(retstring) < reqlen:
            retstring += self.__trans.readAll(reqlen - len(retstring))

        self.__rbuf = BufferIO(retstring)
        return self.__rbuf


class TMemoryBuffer(TTransportBase, CReadableTransport):
    """Wraps a cBytesIO object as a TTransport.

    NOTE: Unlike the C++ version of this class, you cannot write to it
          then immediately read from it.  If you want to read from a
          TMemoryBuffer, you must pass the data to the constructor.
    TODO(dreiss): Make this work like the C++ version.
    """

    def __init__(self, value=None, offset=0):
        """value -- a value to read from for stringio
        offset -- if non-zero, the buffer position is moved there

        If value is set, this will be a transport for reading,
        otherwise, it is for writing"""
        if value is not None:
            self._buffer = BufferIO(value)
        else:
            self._buffer = BufferIO()
        if offset:
            self._buffer.seek(offset)

    def isOpen(self):
        return not self._buffer.closed

    def open(self):
        pass

    def close(self):
        self._buffer.close()

    def read(self, sz):
        return self._buffer.read(sz)

    def write(self, buf):
        self._buffer.write(buf)

    def flush(self):
        pass

    def getvalue(self):
        """Return the entire contents of the underlying buffer."""
        return self._buffer.getvalue()

    # Implement the CReadableTransport interface.
    @property
    def cstringio_buf(self):
        return self._buffer

    def cstringio_refill(self, partialread, reqlen):
        # only one shot at reading...
        raise EOFError()


class TFramedTransportFactory(object):
    """Factory transport that builds framed transports."""

    def getTransport(self, trans):
        """Wrap trans in a TFramedTransport."""
        framed = TFramedTransport(trans)
        return framed


class TFramedTransport(TTransportBase, CReadableTransport):
    """Class that wraps another transport and frames its I/O when writing.

    Each frame is a 4-byte network-order signed length ('!i') followed by
    that many payload bytes.
    """

    def __init__(self, trans,):
        self.__trans = trans
        self.__rbuf = BufferIO(b'')
        self.__wbuf = BufferIO()

    def isOpen(self):
        return self.__trans.isOpen()

    def open(self):
        return self.__trans.open()

    def close(self):
        return self.__trans.close()

    def read(self, sz):
        """Return up to sz bytes from the current frame, loading the next
        frame when the current one is exhausted."""
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret

        self.readFrame()
        return self.__rbuf.read(sz)

    def readFrame(self):
        """Read one length-prefixed frame into the read buffer.

        Raises EOFError (from readAll) if the header or payload is truncated.
        """
        buff = self.__trans.readAll(4)
        sz, = unpack('!i', buff)
        # NOTE(review): sz is not validated here; a negative length is passed
        # straight to readAll -- confirm whether NEGATIVE_SIZE should be raised.
        self.__rbuf = BufferIO(self.__trans.readAll(sz))

    def write(self, buf):
        """Append buf to the pending frame (sent on flush)."""
        self.__wbuf.write(buf)

    def flush(self):
        """Send the buffered data as a single frame and flush the transport."""
        wout = self.__wbuf.getvalue()
        wsz = len(wout)
        # reset wbuf before write/flush to preserve state on underlying failure
        self.__wbuf = BufferIO()
        # N.B.: Doing this string concatenation is WAY cheaper than making
        # two separate calls to the underlying socket object. Socket writes in
        # Python turn out to be REALLY expensive, but it seems to do a pretty
        # good job of managing string buffer operations without excessive copies
        buf = pack("!i", wsz) + wout
        self.__trans.write(buf)
        self.__trans.flush()

    # Implement the CReadableTransport interface.
    @property
    def cstringio_buf(self):
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        """Rebuild the read buffer as prefix plus as many whole frames as are
        needed to reach at least reqlen bytes."""
        # self.__rbuf will already be empty here because fastbinary doesn't
        # ask for a refill until the previous buffer is empty.  Therefore,
        # we can start reading new frames immediately.
        while len(prefix) < reqlen:
            self.readFrame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = BufferIO(prefix)
        return self.__rbuf


class TFileObjectTransport(TTransportBase):
    """Wraps a file-like object to make it work as a Thrift transport."""

    def __init__(self, fileobj):
        self.fileobj = fileobj

    def isOpen(self):
        # Always reports open; the wrapped object's state is not consulted.
        return True

    def close(self):
        self.fileobj.close()

    def read(self, sz):
        return self.fileobj.read(sz)

    def write(self, buf):
        self.fileobj.write(buf)

    def flush(self):
        self.fileobj.flush()


class TSaslClientTransport(TTransportBase, CReadableTransport):
    """
    SASL transport

    Negotiation messages are framed as a 1-byte status plus a 4-byte
    unsigned big-endian length ('>BI') followed by the body. After
    negotiation, data frames are a 4-byte '!i' length followed by the
    SASL-wrapped payload.
    """

    # Negotiation status codes (first byte of every negotiation message).
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    def __init__(self, transport, host, service, mechanism='GSSAPI',
                 **sasl_kwargs):
        """
        transport: an underlying transport to use, typically just a TSocket
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use

        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.
        """

        # Imported lazily so puresasl is only required when SASL is used.
        from puresasl.client import SASLClient

        self.transport = transport
        self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)

        self.__wbuf = BufferIO()
        self.__rbuf = BufferIO(b'')

    def open(self):
        """Open the underlying transport if needed and run SASL negotiation.

        Raises TTransportException(NOT_OPEN) if the server reports COMPLETE
        before the client considers negotiation finished, or answers with any
        status other than OK/COMPLETE.
        """
        if not self.transport.isOpen():
            self.transport.open()

        self.send_sasl_msg(self.START, bytes(self.sasl.mechanism, 'ascii'))
        self.send_sasl_msg(self.OK, self.sasl.process())

        while True:
            status, challenge = self.recv_sasl_msg()
            if status == self.OK:
                self.send_sasl_msg(self.OK, self.sasl.process(challenge))
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    raise TTransportException(
                        TTransportException.NOT_OPEN,
                        "The server erroneously indicated "
                        "that SASL negotiation was complete")
                else:
                    break
            else:
                raise TTransportException(
                    TTransportException.NOT_OPEN,
                    "Bad SASL negotiation status: %d (%s)"
                    % (status, challenge))

    def isOpen(self):
        return self.transport.isOpen()

    def send_sasl_msg(self, status, body):
        """Send one negotiation message (status byte, length, body) and flush."""
        header = pack(">BI", status, len(body))
        self.transport.write(header + body)
        self.transport.flush()

    def recv_sasl_msg(self):
        """Read one negotiation message.

        Returns a (status, payload) tuple; payload is always bytes and is
        empty when the declared length is zero.
        """
        header = self.transport.readAll(5)
        status, length = unpack(">BI", header)
        if length > 0:
            payload = self.transport.readAll(length)
        else:
            # Keep the payload type consistent with readAll(), which yields
            # bytes; a text "" here would mix str and bytes downstream.
            payload = b""
        return status, payload

    def write(self, data):
        """Append data to the pending message (sent on flush)."""
        self.__wbuf.write(data)

    def flush(self):
        """SASL-wrap the buffered data, send it as one frame, and flush."""
        data = self.__wbuf.getvalue()
        # reset wbuf before wrap/write/flush to preserve state on underlying
        # failure: otherwise a failed send would leave the stale message in
        # the buffer and it would be re-sent in front of the next one.
        self.__wbuf = BufferIO()
        encoded = self.sasl.wrap(data)
        self.transport.write(pack("!i", len(encoded)) + encoded)
        self.transport.flush()

    def read(self, sz):
        """Return up to sz bytes of unwrapped data, loading the next frame
        when the current one is exhausted."""
        ret = self.__rbuf.read(sz)
        if len(ret) != 0:
            return ret

        self._read_frame()
        return self.__rbuf.read(sz)

    def _read_frame(self):
        """Read one length-prefixed frame and store its unwrapped payload."""
        header = self.transport.readAll(4)
        length, = unpack('!i', header)
        encoded = self.transport.readAll(length)
        self.__rbuf = BufferIO(self.sasl.unwrap(encoded))

    def close(self):
        """Dispose of the SASL client state, then close the transport."""
        self.sasl.dispose()
        self.transport.close()

    # based on TFramedTransport
    @property
    def cstringio_buf(self):
        return self.__rbuf

    def cstringio_refill(self, prefix, reqlen):
        """Rebuild the read buffer as prefix plus as many whole frames as are
        needed to reach at least reqlen bytes."""
        # self.__rbuf will already be empty here because fastbinary doesn't
        # ask for a refill until the previous buffer is empty.  Therefore,
        # we can start reading new frames immediately.
        while len(prefix) < reqlen:
            self._read_frame()
            prefix += self.__rbuf.getvalue()
        self.__rbuf = BufferIO(prefix)
        return self.__rbuf