1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10#   http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
20from io import BytesIO
21import struct
22
23from zope.interface import implementer, Interface, Attribute
24from twisted.internet.protocol import ServerFactory, ClientFactory, \
25    connectionDone
26from twisted.internet import defer
27from twisted.internet.threads import deferToThread
28from twisted.protocols import basic
29from twisted.web import server, resource, http
30
31from thrift.transport import TTransport
32
33
class TMessageSenderTransport(TTransport.TTransportBase):
    """Write-buffering transport that emits one message per flush().

    Bytes passed to write() accumulate in an in-memory buffer; flush()
    hands the accumulated bytes to sendMessage(), resets the buffer, and
    returns whatever sendMessage() returned. Subclasses must implement
    sendMessage().
    """

    def __init__(self):
        self.__wbuf = BytesIO()

    def write(self, buf):
        self.__wbuf.write(buf)

    def flush(self):
        # Swap in a fresh buffer before sending so the next message starts
        # empty even if sendMessage() raises.
        pending = self.__wbuf
        self.__wbuf = BytesIO()
        return self.sendMessage(pending.getvalue())

    def sendMessage(self, message):
        """Deliver one complete message; subclasses must override."""
        raise NotImplementedError
49
50
class TCallbackTransport(TMessageSenderTransport):
    """Message transport that forwards every flushed message to a callable.

    func: called as func(message) on each flush; its return value is
        returned from flush().
    """

    def __init__(self, func):
        TMessageSenderTransport.__init__(self)
        self.func = func

    def sendMessage(self, message):
        callback = self.func
        return callback(message)
59
60
class ThriftClientProtocol(basic.Int32StringReceiver):
    """Client side of framed (Int32 length-prefixed) Thrift over TCP.

    When the connection is made, client_class is instantiated on a
    TCallbackTransport whose flushed messages are sent as frames, and
    ``self.started`` fires with that client. Every incoming frame is routed
    to the client's ``recv_<method name>`` handler.
    """

    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        """
        client_class: Thrift client class, instantiated as
            client_class(transport, oprot_factory) in connectionMade()
        iprot_factory: protocol factory used to decode incoming frames
        oprot_factory: protocol factory handed to the client for outgoing
            requests; defaults to iprot_factory
        """
        self._client_class = client_class
        self._iprot_factory = iprot_factory
        if oprot_factory is None:
            self._oprot_factory = iprot_factory
        else:
            self._oprot_factory = oprot_factory

        # Cache of method name -> bound recv_* handler of self.client.
        self.recv_map = {}
        self.started = defer.Deferred()
        # Assigned in connectionMade(); initialised here so connectionLost()
        # can test it even when the client was never created.
        self.client = None

    def dispatch(self, msg):
        """Send one serialized Thrift message as a length-prefixed frame."""
        self.sendString(msg)

    def connectionMade(self):
        tmo = TCallbackTransport(self.dispatch)
        self.client = self._client_class(tmo, self._oprot_factory)
        self.started.callback(self.client)

    def connectionLost(self, reason=connectionDone):
        """Fail every outstanding request and drop the client."""
        if self.client:
            tex = TTransport.TTransportException(
                type=TTransport.TTransportException.END_OF_FILE,
                message='Connection closed (%s)' % reason)
            # Errbacks may issue new requests, which add entries to the
            # client's _reqs, so keep popping until nothing is left.
            while self.client._reqs:
                _, v = self.client._reqs.popitem()
                v.errback(tex)
            del self.client._reqs
            self.client = None

    def stringReceived(self, frame):
        """Decode one reply frame and hand it to the matching recv_* method."""
        tr = TTransport.TMemoryBuffer(frame)
        iprot = self._iprot_factory.getProtocol(tr)
        (fname, mtype, rseqid) = iprot.readMessageBegin()

        try:
            method = self.recv_map[fname]
        except KeyError:
            method = getattr(self.client, 'recv_' + fname)
            self.recv_map[fname] = method

        method(iprot, mtype, rseqid)
110
111
class ThriftSASLClientProtocol(ThriftClientProtocol):
    """ThriftClientProtocol that performs SASL negotiation before use.

    Negotiation messages are framed as a 1-byte status followed by a 4-byte
    big-endian length and the payload. After negotiation completes, each
    outgoing Thrift message is wrapped with ``self.sasl.wrap`` and given an
    inner 4-byte length prefix inside the normal Int32 frame; incoming
    frames have that inner prefix stripped and are unwrapped.

    If negotiation fails (bad status, premature COMPLETE, or the connection
    dropping mid-negotiation), ``self.started`` errbacks with the error and
    the connection is closed.
    """

    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None,
                 host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
        """
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use

        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor. When host is None no SASL client is created here and
        createSASLClient() must be called before the connection is made.
        """

        from puresasl.client import SASLClient
        # createSASLClient() looks the class up as ``self.SASLClient``. The
        # old misspelled attribute name is kept as an alias in case any
        # external code referenced it.
        self.SASLClient = SASLClient
        self.SASLCLient = SASLClient

        ThriftClientProtocol.__init__(self, client_class, iprot_factory, oprot_factory)

        # Pending Deferred for the next SASL message; None outside negotiation.
        self._sasl_negotiation_deferred = None
        # Status byte of the SASL message currently being received.
        self._sasl_negotiation_status = None
        self.client = None

        if host is not None:
            self.createSASLClient(host, service, mechanism, **sasl_kwargs)

    def createSASLClient(self, host, service, mechanism, **kwargs):
        self.sasl = self.SASLClient(host, service, mechanism, **kwargs)

    def dispatch(self, msg):
        encoded = self.sasl.wrap(msg)
        # Inner length prefix; sendString() in the base class adds the outer one.
        len_and_encoded = b''.join((struct.pack('!i', len(encoded)), encoded))
        ThriftClientProtocol.dispatch(self, len_and_encoded)

    @defer.inlineCallbacks
    def connectionMade(self):
        try:
            self._sendSASLMessage(self.START, self.sasl.mechanism)
            # sasl.process may block (e.g. GSSAPI), so run it off the reactor.
            initial_message = yield deferToThread(self.sasl.process)
            self._sendSASLMessage(self.OK, initial_message)

            while True:
                status, challenge = yield self._receiveSASLMessage()
                if status == self.OK:
                    response = yield deferToThread(self.sasl.process, challenge)
                    self._sendSASLMessage(self.OK, response)
                elif status == self.COMPLETE:
                    if not self.sasl.complete:
                        msg = "The server erroneously indicated that SASL " \
                              "negotiation was complete"
                        raise TTransport.TTransportException(
                            type=TTransport.TTransportException.NOT_OPEN,
                            message=msg)
                    break
                else:
                    msg = "Bad SASL negotiation status: %d (%s)" % (status, challenge)
                    raise TTransport.TTransportException(
                        type=TTransport.TTransportException.NOT_OPEN,
                        message=msg)
        except Exception:
            # Report the failure to whoever waits on ``started`` instead of
            # leaving it as an unhandled error on this method's Deferred.
            self._sasl_negotiation_deferred = None
            if not self.started.called:
                self.started.errback()
            self.transport.loseConnection()
            return

        self._sasl_negotiation_deferred = None
        ThriftClientProtocol.connectionMade(self)

    def _sendSASLMessage(self, status, body):
        if body is None:
            body = b""
        elif not isinstance(body, bytes):
            # Text bodies (presumably the mechanism name sent with START)
            # must be encoded before being framed.
            body = body.encode('utf-8')
        header = struct.pack(">BI", status, len(body))
        self.transport.write(header + body)

    def _receiveSASLMessage(self):
        self._sasl_negotiation_deferred = defer.Deferred()
        self._sasl_negotiation_status = None
        return self._sasl_negotiation_deferred

    def connectionLost(self, reason=connectionDone):
        pending = self._sasl_negotiation_deferred
        if pending is not None and not pending.called:
            # Wake connectionMade() so negotiation fails instead of hanging.
            self._sasl_negotiation_deferred = None
            pending.errback(TTransport.TTransportException(
                type=TTransport.TTransportException.END_OF_FILE,
                message='Connection closed during SASL negotiation (%s)' % reason))
        if self.client:
            ThriftClientProtocol.connectionLost(self, reason)

    def dataReceived(self, data):
        if self._sasl_negotiation_deferred:
            if not data:
                return
            # we got a sasl challenge in the format (status, length, challenge)
            # save the status, let IntNStringReceiver piece the challenge data together
            # NOTE(review): assumes every chunk received during negotiation
            # starts at a SASL message boundary.
            # data[0:1] (not data[0]) keeps this a bytes object on Python 3.
            self._sasl_negotiation_status, = struct.unpack("B", data[0:1])
            ThriftClientProtocol.dataReceived(self, data[1:])
        else:
            # normal frame, let IntNStringReceiver piece it together
            ThriftClientProtocol.dataReceived(self, data)

    def stringReceived(self, frame):
        if self._sasl_negotiation_deferred:
            # the frame is just a SASL challenge
            response = (self._sasl_negotiation_status, frame)
            self._sasl_negotiation_deferred.callback(response)
        else:
            # there's a second 4 byte length prefix inside the frame
            decoded_frame = self.sasl.unwrap(frame[4:])
            ThriftClientProtocol.stringReceived(self, decoded_frame)
212
213
class ThriftServerProtocol(basic.Int32StringReceiver):
    """Server side of framed (Int32 length-prefixed) Thrift over TCP.

    Each incoming frame is decoded with the factory's iprot_factory and run
    through the factory's processor. A non-empty reply is sent back as a
    frame; a processing failure closes the connection.
    """

    MAX_LENGTH = 2 ** 31 - 1

    def dispatch(self, msg):
        self.sendString(msg)

    def processError(self, error):
        self.transport.loseConnection()

    def processOk(self, _, tmo):
        reply = tmo.getvalue()
        # An empty reply buffer means there is nothing to send back.
        if reply:
            self.dispatch(reply)

    def stringReceived(self, frame):
        factory = self.factory
        in_buf = TTransport.TMemoryBuffer(frame)
        out_buf = TTransport.TMemoryBuffer()

        d = factory.processor.process(
            factory.iprot_factory.getProtocol(in_buf),
            factory.oprot_factory.getProtocol(out_buf))
        d.addCallbacks(self.processOk, self.processError,
                       callbackArgs=(out_buf,))
240
241
class IThriftServerFactory(Interface):
    """Attributes a factory exposes to ThriftServerProtocol instances."""

    processor = Attribute("Thrift processor")
    iprot_factory = Attribute("Input protocol factory")
    oprot_factory = Attribute("Output protocol factory")
249
250
class IThriftClientFactory(Interface):
    """Attributes a factory exposes for building Thrift client protocols."""

    client_class = Attribute("Thrift client class")
    iprot_factory = Attribute("Input protocol factory")
    oprot_factory = Attribute("Output protocol factory")
258
259
@implementer(IThriftServerFactory)
class ThriftServerFactory(ServerFactory):
    """ServerFactory whose protocol class is ThriftServerProtocol.

    processor: Thrift processor that handles each decoded request
    iprot_factory: protocol factory for decoding requests
    oprot_factory: protocol factory for encoding replies; defaults to
        iprot_factory
    """

    protocol = ThriftServerProtocol

    def __init__(self, processor, iprot_factory, oprot_factory=None):
        self.processor = processor
        self.iprot_factory = iprot_factory
        self.oprot_factory = (iprot_factory if oprot_factory is None
                              else oprot_factory)
272
273
@implementer(IThriftClientFactory)
class ThriftClientFactory(ClientFactory):
    """ClientFactory that builds ThriftClientProtocol instances.

    client_class: Thrift client class passed to each protocol
    iprot_factory: protocol factory for decoding replies
    oprot_factory: protocol factory for encoding requests; defaults to
        iprot_factory
    """

    protocol = ThriftClientProtocol

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        self.client_class = client_class
        self.iprot_factory = iprot_factory
        self.oprot_factory = (iprot_factory if oprot_factory is None
                              else oprot_factory)

    def buildProtocol(self, addr):
        """Create a protocol for addr and link it back to this factory."""
        proto = self.protocol(self.client_class, self.iprot_factory,
                              self.oprot_factory)
        proto.factory = self
        return proto
292
293
class ThriftResource(resource.Resource):
    """twisted.web resource that serves Thrift calls over HTTP POST.

    Every child path resolves to this same resource. The POST body is
    decoded with inputProtocolFactory and run through processor. The reply
    is written back with content-type application/x-thrift. If the
    processor's Deferred fails, the failure is logged and the request is
    answered with HTTP 500.
    """

    allowedMethods = ('POST',)

    def __init__(self, processor, inputProtocolFactory,
                 outputProtocolFactory=None):
        resource.Resource.__init__(self)
        self.inputProtocolFactory = inputProtocolFactory
        if outputProtocolFactory is None:
            self.outputProtocolFactory = inputProtocolFactory
        else:
            self.outputProtocolFactory = outputProtocolFactory
        self.processor = processor

    def getChild(self, path, request):
        return self

    def _cbProcess(self, _, request, tmo):
        """Write the serialized reply held in tmo and finish the request."""
        msg = tmo.getvalue()
        request.setResponseCode(http.OK)
        request.setHeader("content-type", "application/x-thrift")
        request.write(msg)
        request.finish()

    def _ebProcess(self, failure, request):
        """Log a processor failure and answer the request with HTTP 500.

        Without this errback a failed call would never finish the request,
        leaving the HTTP client waiting indefinitely.
        """
        from twisted.python import log
        log.err(failure, "Thrift processor failed")
        request.setResponseCode(http.INTERNAL_SERVER_ERROR)
        request.finish()

    def render_POST(self, request):
        request.content.seek(0, 0)
        data = request.content.read()
        tmi = TTransport.TMemoryBuffer(data)
        tmo = TTransport.TMemoryBuffer()

        iprot = self.inputProtocolFactory.getProtocol(tmi)
        oprot = self.outputProtocolFactory.getProtocol(tmo)

        d = self.processor.process(iprot, oprot)
        # addCallbacks (not addCallback().addErrback()) so _ebProcess only
        # handles processor failures, never errors raised by _cbProcess.
        d.addCallbacks(self._cbProcess, self._ebProcess,
                       callbackArgs=(request, tmo),
                       errbackArgs=(request,))
        return server.NOT_DONE_YET
330