#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

"""Twisted protocols, factories and web resource for Thrift services."""

from io import BytesIO
import struct

from zope.interface import implementer, Interface, Attribute
from twisted.internet.protocol import ServerFactory, ClientFactory, \
    connectionDone
from twisted.internet import defer
from twisted.internet.threads import deferToThread
from twisted.protocols import basic
from twisted.web import server, resource, http

from thrift.transport import TTransport


class TMessageSenderTransport(TTransport.TTransportBase):
    """Write-only transport that buffers writes and emits them on flush.

    Subclasses implement sendMessage() to deliver the buffered bytes.
    """

    def __init__(self):
        self.__wbuf = BytesIO()

    def write(self, buf):
        """Append ``buf`` to the pending message buffer."""
        self.__wbuf.write(buf)

    def flush(self):
        """Pass the buffered bytes to sendMessage() and reset the buffer.

        Returns whatever sendMessage() returns.
        """
        msg = self.__wbuf.getvalue()
        self.__wbuf = BytesIO()
        return self.sendMessage(msg)

    def sendMessage(self, message):
        """Deliver one complete message; subclasses must override this."""
        raise NotImplementedError


class TCallbackTransport(TMessageSenderTransport):
    """Message transport that delivers each flushed message to a callable."""

    def __init__(self, func):
        """
        func: callable invoked with the message bytes on every flush
        """
        TMessageSenderTransport.__init__(self)
        self.func = func

    def sendMessage(self, message):
        """Call the configured callable with ``message``; return its result."""
        return self.func(message)


class ThriftClientProtocol(basic.Int32StringReceiver):
    """Client protocol exchanging Thrift messages as int32-prefixed frames.

    ``started`` is a Deferred fired with the client instance once the
    connection has been made.
    """

    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        """
        client_class: called as client_class(transport, oprot_factory)
        iprot_factory: protocol factory used to decode incoming frames
        oprot_factory: protocol factory handed to the client; defaults to
            iprot_factory when None
        """
        self._client_class = client_class
        self._iprot_factory = iprot_factory
        if oprot_factory is None:
            self._oprot_factory = iprot_factory
        else:
            self._oprot_factory = oprot_factory

        self.recv_map = {}
        self.started = defer.Deferred()
        # connectionLost() reads self.client, so make sure the attribute
        # exists even if connectionMade() never ran.
        self.client = None

    def dispatch(self, msg):
        """Send ``msg`` to the peer as one length-prefixed frame."""
        self.sendString(msg)

    def connectionMade(self):
        """Create the client on top of this connection and fire ``started``."""
        tmo = TCallbackTransport(self.dispatch)
        self.client = self._client_class(tmo, self._oprot_factory)
        self.started.callback(self.client)

    def connectionLost(self, reason=connectionDone):
        """Fail every pending request of the client and drop the client."""
        # the called errbacks can add items to our client's _reqs,
        # so we need to use a tmp, and iterate until no more requests
        # are added during errbacks
        if self.client:
            tex = TTransport.TTransportException(
                type=TTransport.TTransportException.END_OF_FILE,
                message='Connection closed (%s)' % reason)
            while self.client._reqs:
                _, v = self.client._reqs.popitem()
                v.errback(tex)
            del self.client._reqs
            self.client = None

    def stringReceived(self, frame):
        """Decode one frame and pass it to the client's ``recv_<name>``."""
        tr = TTransport.TMemoryBuffer(frame)
        iprot = self._iprot_factory.getProtocol(tr)
        (fname, mtype, rseqid) = iprot.readMessageBegin()

        # recv_map caches the bound recv_* method per message name.
        try:
            method = self.recv_map[fname]
        except KeyError:
            method = getattr(self.client, 'recv_' + fname)
            self.recv_map[fname] = method

        method(iprot, mtype, rseqid)


class ThriftSASLClientProtocol(ThriftClientProtocol):
    """Thrift client protocol that negotiates SASL before creating the client.

    During negotiation, messages are (1 byte status, 4 byte length, body).
    After negotiation every Thrift message is passed through the SASL
    client's wrap()/unwrap().
    """

    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None,
                 host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
        """
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use

        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.

        The SASL client is only created here when ``host`` is not None;
        otherwise createSASLClient() has to be called before connecting.
        """

        from puresasl.client import SASLClient
        # createSASLClient() reads self.SASLClient; the attribute used to be
        # stored only under the misspelled name below, which made
        # createSASLClient() raise AttributeError.
        self.SASLClient = SASLClient
        # Misspelled alias kept for code that relied on the old name.
        self.SASLCLient = SASLClient

        ThriftClientProtocol.__init__(self, client_class, iprot_factory,
                                      oprot_factory)

        self._sasl_negotiation_deferred = None
        self._sasl_negotiation_status = None
        self.client = None

        if host is not None:
            self.createSASLClient(host, service, mechanism, **sasl_kwargs)

    def createSASLClient(self, host, service, mechanism, **kwargs):
        """Instantiate the SASL client and store it as ``self.sasl``."""
        self.sasl = self.SASLClient(host, service, mechanism, **kwargs)

    def dispatch(self, msg):
        """SASL-wrap ``msg``, add the inner length prefix and send it."""
        encoded = self.sasl.wrap(msg)
        # Both parts are bytes, so the separator has to be bytes as well.
        len_and_encoded = b''.join(
            (struct.pack('!i', len(encoded)), encoded))
        ThriftClientProtocol.dispatch(self, len_and_encoded)

    @defer.inlineCallbacks
    def connectionMade(self):
        """Run the SASL handshake, then finish the normal connection setup.

        Raises TTransportException (through the returned Deferred) when the
        server reports completion before the SASL client is complete, or
        sends a status other than OK or COMPLETE.
        """
        self._sendSASLMessage(self.START, self.sasl.mechanism)
        # sasl.process() is run in a thread so it cannot block the reactor.
        initial_message = yield deferToThread(self.sasl.process)
        self._sendSASLMessage(self.OK, initial_message)

        while True:
            status, challenge = yield self._receiveSASLMessage()
            if status == self.OK:
                response = yield deferToThread(self.sasl.process, challenge)
                self._sendSASLMessage(self.OK, response)
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    msg = "The server erroneously indicated that SASL " \
                          "negotiation was complete"
                    # The first positional parameter is the exception type
                    # code, not the message text.
                    raise TTransport.TTransportException(
                        type=TTransport.TTransportException.UNKNOWN,
                        message=msg)
                else:
                    break
            else:
                msg = "Bad SASL negotiation status: %d (%s)" % (
                    status, challenge)
                raise TTransport.TTransportException(
                    type=TTransport.TTransportException.UNKNOWN,
                    message=msg)

        self._sasl_negotiation_deferred = None
        ThriftClientProtocol.connectionMade(self)

    def _sendSASLMessage(self, status, body):
        """Write one negotiation message: status byte, length, body.

        ``body`` may be None (sent as empty), bytes, or text; text (such as
        the mechanism name) is encoded so it can be joined to the header.
        """
        if body is None:
            body = b""
        elif not isinstance(body, bytes):
            body = body.encode("utf-8")
        header = struct.pack(">BI", status, len(body))
        self.transport.write(header + body)

    def _receiveSASLMessage(self):
        """Return a Deferred fired with the next (status, body) message."""
        self._sasl_negotiation_deferred = defer.Deferred()
        # None marks that the status byte of the next message is still
        # to be read by dataReceived().
        self._sasl_negotiation_status = None
        return self._sasl_negotiation_deferred

    def connectionLost(self, reason=connectionDone):
        """Fail pending requests, but only if the client was ever created."""
        if self.client:
            ThriftClientProtocol.connectionLost(self, reason)

    def dataReceived(self, data):
        """Strip the SASL status byte during negotiation, then frame data."""
        if self._sasl_negotiation_deferred:
            # we got a sasl challenge in the format (status, length, challenge)
            # save the status, let IntNStringReceiver piece the challenge data
            # together
            if self._sasl_negotiation_status is None:
                if not data:
                    return
                # Slice rather than index: indexing bytes yields an int on
                # Python 3, which struct.unpack() rejects.
                self._sasl_negotiation_status, = struct.unpack(
                    "B", data[0:1])
                data = data[1:]
            # When the status is already known this chunk continues a
            # fragmented message and must be passed on untouched.
            ThriftClientProtocol.dataReceived(self, data)
        else:
            # normal frame, let IntNStringReceiver piece it together
            ThriftClientProtocol.dataReceived(self, data)

    def stringReceived(self, frame):
        """Route a frame to the SASL handshake or to the Thrift client."""
        if self._sasl_negotiation_deferred:
            # the frame is just a SASL challenge
            response = (self._sasl_negotiation_status, frame)
            self._sasl_negotiation_deferred.callback(response)
        else:
            # there's a second 4 byte length prefix inside the frame
            decoded_frame = self.sasl.unwrap(frame[4:])
            ThriftClientProtocol.stringReceived(self, decoded_frame)


class ThriftServerProtocol(basic.Int32StringReceiver):
    """Server protocol that feeds each received frame to the processor.

    The processor and protocol factories are read from ``self.factory``.
    """

    MAX_LENGTH = 2 ** 31 - 1

    def dispatch(self, msg):
        """Send ``msg`` to the peer as one length-prefixed frame."""
        self.sendString(msg)

    def processError(self, error):
        """Errback for request processing: close the connection."""
        self.transport.loseConnection()

    def processOk(self, _, tmo):
        """Callback for request processing: send the reply if non-empty."""
        msg = tmo.getvalue()

        if len(msg) > 0:
            self.dispatch(msg)

    def stringReceived(self, frame):
        """Process one request frame and arrange for the reply to be sent."""
        tmi = TTransport.TMemoryBuffer(frame)
        tmo = TTransport.TMemoryBuffer()

        iprot = self.factory.iprot_factory.getProtocol(tmi)
        oprot = self.factory.oprot_factory.getProtocol(tmo)

        d = self.factory.processor.process(iprot, oprot)
        d.addCallbacks(self.processOk, self.processError,
                       callbackArgs=(tmo,))


class IThriftServerFactory(Interface):
    """Attributes a server factory provides to ThriftServerProtocol."""

    processor = Attribute("Thrift processor")

    iprot_factory = Attribute("Input protocol factory")

    oprot_factory = Attribute("Output protocol factory")


class IThriftClientFactory(Interface):
    """Attributes a client factory uses to build client protocols."""

    client_class = Attribute("Thrift client class")

    iprot_factory = Attribute("Input protocol factory")

    oprot_factory = Attribute("Output protocol factory")


@implementer(IThriftServerFactory)
class ThriftServerFactory(ServerFactory):
    """Server factory holding the processor and the protocol factories."""

    protocol = ThriftServerProtocol

    def __init__(self, processor, iprot_factory, oprot_factory=None):
        """
        processor: Thrift processor used for every request
        iprot_factory: input protocol factory
        oprot_factory: output protocol factory; defaults to iprot_factory
            when None
        """
        self.processor = processor
        self.iprot_factory = iprot_factory
        if oprot_factory is None:
            self.oprot_factory = iprot_factory
        else:
            self.oprot_factory = oprot_factory


@implementer(IThriftClientFactory)
class ThriftClientFactory(ClientFactory):
    """Client factory building protocols for one Thrift client class."""

    protocol = ThriftClientProtocol

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        """
        client_class: Thrift client class passed to every built protocol
        iprot_factory: input protocol factory
        oprot_factory: output protocol factory; defaults to iprot_factory
            when None
        """
        self.client_class = client_class
        self.iprot_factory = iprot_factory
        if oprot_factory is None:
            self.oprot_factory = iprot_factory
        else:
            self.oprot_factory = oprot_factory

    def buildProtocol(self, addr):
        """Build a protocol instance and attach this factory to it."""
        p = self.protocol(self.client_class, self.iprot_factory,
                          self.oprot_factory)
        p.factory = self
        return p


class ThriftResource(resource.Resource):
    """Web resource that processes a Thrift request sent in a POST body."""

    allowedMethods = ('POST',)

    def __init__(self, processor, inputProtocolFactory,
                 outputProtocolFactory=None):
        """
        processor: Thrift processor used for every request
        inputProtocolFactory: factory for decoding the request body
        outputProtocolFactory: factory for encoding the response; defaults
            to inputProtocolFactory when None
        """
        resource.Resource.__init__(self)
        self.inputProtocolFactory = inputProtocolFactory
        if outputProtocolFactory is None:
            self.outputProtocolFactory = inputProtocolFactory
        else:
            self.outputProtocolFactory = outputProtocolFactory
        self.processor = processor

    def getChild(self, path, request):
        """Return this resource for every child path."""
        return self

    def _cbProcess(self, _, request, tmo):
        """Write the processor's output as the HTTP response and finish."""
        msg = tmo.getvalue()
        request.setResponseCode(http.OK)
        request.setHeader("content-type", "application/x-thrift")
        request.write(msg)
        request.finish()

    def render_POST(self, request):
        """Run the processor on the request body; respond when it is done."""
        request.content.seek(0, 0)
        data = request.content.read()
        tmi = TTransport.TMemoryBuffer(data)
        tmo = TTransport.TMemoryBuffer()

        iprot = self.inputProtocolFactory.getProtocol(tmi)
        oprot = self.outputProtocolFactory.getProtocol(tmo)

        d = self.processor.process(iprot, oprot)
        d.addCallback(self._cbProcess, request, tmo)
        # The response is written later by _cbProcess.
        return server.NOT_DONE_YET