#!/usr/bin/env python

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from __future__ import division
import binascii
import logging
import os
import signal
import sys
import time
from optparse import OptionParser

from util import local_libpath
sys.path.insert(0, local_libpath())
from thrift.protocol import TProtocol, TProtocolDecorator

SCRIPT_DIR = os.path.abspath(os.path.dirname(__file__))


class TestHandler(object):
    """Handler implementing the ThriftTest service for the test server.

    Most methods simply log their call (only when ``options.verbose > 1``)
    and return their argument unchanged. The exceptions are documented on
    the individual methods.

    The methods read module-level names that are bound in the ``__main__``
    section of this script rather than at import time: ``options``,
    ``Xtruct``, ``Xception``, ``Xception2``, ``Insanity`` and ``TException``.
    """

    def testVoid(self):
        if options.verbose > 1:
            logging.info('testVoid()')

    def testString(self, str):
        if options.verbose > 1:
            logging.info('testString(%s)' % str)
        return str

    def testBool(self, boolean):
        if options.verbose > 1:
            logging.info('testBool(%s)' % str(boolean).lower())
        return boolean

    def testByte(self, byte):
        if options.verbose > 1:
            logging.info('testByte(%d)' % byte)
        return byte

    def testI16(self, i16):
        if options.verbose > 1:
            logging.info('testI16(%d)' % i16)
        return i16

    def testI32(self, i32):
        if options.verbose > 1:
            logging.info('testI32(%d)' % i32)
        return i32

    def testI64(self, i64):
        if options.verbose > 1:
            logging.info('testI64(%d)' % i64)
        return i64

    def testDouble(self, dub):
        if options.verbose > 1:
            logging.info('testDouble(%f)' % dub)
        return dub

    def testBinary(self, thing):
        """Echo a binary payload; when verbose, log it as a hex string."""
        if options.verbose > 1:
            # hexlify accepts the raw payload bytes; decode so the log shows
            # plain hex digits rather than a bytes repr on Python 3.
            logging.info('testBinary(%s)' % binascii.hexlify(thing).decode('ascii'))
        return thing

    def testStruct(self, thing):
        if options.verbose > 1:
            logging.info('testStruct({%s, %s, %s, %s})' % (thing.string_thing, thing.byte_thing, thing.i32_thing, thing.i64_thing))
        return thing

    def testException(self, arg):
        """Raise an exception selected by *arg*.

        ``'Xception'`` raises ``Xception(errorCode=1001, message=arg)``;
        ``'TException'`` raises a plain ``TException``; any other value
        returns ``None``.
        """
        # NOTE(review): unlike every other handler, this logs regardless of
        # options.verbose -- confirm whether a verbosity guard belongs here.
        logging.info('testException(%s)' % arg)
        if arg == 'Xception':
            raise Xception(errorCode=1001, message=arg)
        elif arg == 'TException':
            raise TException(message='This is a TException')

    def testMultiException(self, arg0, arg1):
        """Raise Xception / Xception2 depending on *arg0*, else echo *arg1*.

        Returns ``Xtruct(string_thing=arg1)`` when *arg0* names neither
        exception.
        """
        if options.verbose > 1:
            logging.info('testMultiException(%s, %s)' % (arg0, arg1))
        if arg0 == 'Xception':
            raise Xception(errorCode=1001, message='This is an Xception')
        elif arg0 == 'Xception2':
            raise Xception2(
                errorCode=2002,
                struct_thing=Xtruct(string_thing='This is an Xception2'))
        return Xtruct(string_thing=arg1)

    def testOneway(self, seconds):
        """Sleep for a third of *seconds* (kept short so tests stay quick)."""
        if options.verbose > 1:
            logging.info('testOneway(%d) => sleeping...' % seconds)
        time.sleep(seconds / 3)  # be quick; true division via __future__
        if options.verbose > 1:
            logging.info('done sleeping')

    def testNest(self, thing):
        if options.verbose > 1:
            logging.info('testNest(%s)' % thing)
        return thing

    def testMap(self, thing):
        if options.verbose > 1:
            logging.info('testMap(%s)' % thing)
        return thing

    def testStringMap(self, thing):
        if options.verbose > 1:
            logging.info('testStringMap(%s)' % thing)
        return thing

    def testSet(self, thing):
        if options.verbose > 1:
            logging.info('testSet(%s)' % thing)
        return thing

    def testList(self, thing):
        if options.verbose > 1:
            logging.info('testList(%s)' % thing)
        return thing

    def testEnum(self, thing):
        if options.verbose > 1:
            logging.info('testEnum(%s)' % thing)
        return thing

    def testTypedef(self, thing):
        if options.verbose > 1:
            logging.info('testTypedef(%s)' % thing)
        return thing

    def testMapMap(self, thing):
        """Ignore *thing* and return a fixed two-level map keyed by -4 and 4."""
        if options.verbose > 1:
            logging.info('testMapMap(%s)' % thing)
        return {
            -4: {
                -4: -4,
                -3: -3,
                -2: -2,
                -1: -1,
            },
            4: {
                4: 4,
                3: 3,
                2: 2,
                1: 1,
            },
        }

    def testInsanity(self, argument):
        """Return ``{1: {2: argument, 3: argument}, 2: {6: Insanity()}}``."""
        if options.verbose > 1:
            logging.info('testInsanity(%s)' % argument)
        return {
            1: {
                2: argument,
                3: argument,
            },
            2: {6: Insanity()},
        }

    def testMulti(self, arg0, arg1, arg2, arg3, arg4, arg5):
        """Return an Xtruct built from the first three arguments.

        *arg3*, *arg4* and *arg5* are only logged.
        """
        if options.verbose > 1:
            logging.info('testMulti(%s, %s, %s, %s, %s, %s)' % (arg0, arg1, arg2, arg3, arg4, arg5))
        return Xtruct(string_thing='Hello2',
                      byte_thing=arg0, i32_thing=arg1, i64_thing=arg2)


class SecondHandler(object):
    """Handler for SecondService, registered only for the multiplexed protocols."""

    def secondtestString(self, argument):
        """Return *argument* wrapped as ``testString("<argument>")``."""
        return "testString(\"" + argument + "\")"


# LAST_SEQID is a global because we have one transport and multiple protocols
# running on it (when multiplexed)
LAST_SEQID = None
class TPedanticSequenceIdProtocolWrapper(TProtocolDecorator.TProtocolDecorator):
    """
    Wraps any protocol with sequence ID checking.

    Only readMessageBegin is checked. It raises a TProtocolException
    (INVALID_DATA) when an incoming message carries the same seqid as the
    message read immediately before it. The previous seqid is tracked in the
    module-level LAST_SEQID so it is shared across multiplexed protocols.
    """
    def __init__(self, protocol):
        # TProtocolDecorator.__new__ does all the heavy lifting
        pass

    def readMessageBegin(self):
        global LAST_SEQID
        name, mtype, seqid = \
            super(TPedanticSequenceIdProtocolWrapper, self).readMessageBegin()
        if LAST_SEQID is not None and LAST_SEQID == seqid:
            raise TProtocol.TProtocolException(
                TProtocol.TProtocolException.INVALID_DATA,
                "We received the same seqid {0} twice in a row".format(seqid))
        LAST_SEQID = seqid
        return (name, mtype, seqid)


def make_pedantic(proto):
    """ Wrap a protocol in the pedantic sequence ID wrapper. """
    # NOTE: this is disabled for now as many clients send seqid
    #       of zero and that is okay, need a way to identify
    #       clients that MUST send seqid unique to function right
    #       or just force all implementations to send unique seqids (preferred)
    return proto  # TPedanticSequenceIdProtocolWrapper(proto)


class TPedanticSequenceIdProtocolFactory(TProtocol.TProtocolFactory):
    """Protocol factory that passes every protocol it builds through make_pedantic()."""

    def __init__(self, encapsulated):
        super(TPedanticSequenceIdProtocolFactory, self).__init__()
        self.encapsulated = encapsulated

    def getProtocol(self, trans):
        return make_pedantic(self.encapsulated.getProtocol(trans))


def main(options, args=None):
    """Build and run the ThriftTest server described by *options*.

    :param options: parsed command-line options from the ``__main__`` section.
    :param args: positional command-line arguments. ``args[0]`` names the
        server type (e.g. TSimpleServer, TNonblockingServer,
        TProcessPoolServer). When omitted, the module-level ``args`` bound in
        ``__main__`` is used, as before.
    :raises AssertionError: on an unknown --protocol, an empty --transport,
        more than one server type, or no server type when one is required.
    """
    global server

    if args is None:
        # Backward compatibility: main() used to read the global ``args``.
        args = globals().get('args') or []

    # common header allowed client types
    allowed_client_types = [
        THeaderTransport.THeaderClientType.HEADERS,
        THeaderTransport.THeaderClientType.FRAMED_BINARY,
        THeaderTransport.THeaderClientType.UNFRAMED_BINARY,
        THeaderTransport.THeaderClientType.FRAMED_COMPACT,
        THeaderTransport.THeaderClientType.UNFRAMED_COMPACT,
    ]

    # set up the protocol factory from the --protocol option
    prot_factories = {
        'accel': TBinaryProtocol.TBinaryProtocolAcceleratedFactory(),
        'multia': TBinaryProtocol.TBinaryProtocolAcceleratedFactory(),
        'accelc': TCompactProtocol.TCompactProtocolAcceleratedFactory(),
        'multiac': TCompactProtocol.TCompactProtocolAcceleratedFactory(),
        'binary': TPedanticSequenceIdProtocolFactory(TBinaryProtocol.TBinaryProtocolFactory()),
        'multi': TPedanticSequenceIdProtocolFactory(TBinaryProtocol.TBinaryProtocolFactory()),
        'compact': TCompactProtocol.TCompactProtocolFactory(),
        'multic': TCompactProtocol.TCompactProtocolFactory(),
        'header': THeaderProtocol.THeaderProtocolFactory(allowed_client_types),
        'multih': THeaderProtocol.THeaderProtocolFactory(allowed_client_types),
        'json': TJSONProtocol.TJSONProtocolFactory(),
        'multij': TJSONProtocol.TJSONProtocolFactory(),
    }
    pfactory = prot_factories.get(options.proto, None)
    if pfactory is None:
        raise AssertionError('Unknown --protocol option: %s' % options.proto)
    try:
        pfactory.string_length_limit = options.string_limit
        pfactory.container_length_limit = options.container_limit
    except AttributeError:
        # Ignore errors for those protocols that do not support length limits
        pass

    # get the server type (TSimpleServer, TNonblockingServer, etc...)
    if len(args) > 1:
        raise AssertionError('Only one server type may be specified, not multiple types.')
    if options.trans == 'http':
        # The HTTP transport always runs on THttpServer, so no positional
        # server type is needed in this case.
        server_type = 'THttpServer'
    elif args:
        server_type = args[0]
    else:
        raise AssertionError('A server type (e.g. TSimpleServer) must be specified.')

    # Set up the handler and processor objects
    handler = TestHandler()
    processor = ThriftTest.Processor(handler)

    if options.proto.startswith('multi'):
        # Multiplexed protocols serve ThriftTest both as the default service
        # and under its name, plus SecondService.
        second_handler = SecondHandler()
        second_processor = SecondService.Processor(second_handler)

        multiplexed_processor = TMultiplexedProcessor()
        multiplexed_processor.registerDefault(processor)
        multiplexed_processor.registerProcessor('ThriftTest', processor)
        multiplexed_processor.registerProcessor('SecondService', second_processor)
        processor = multiplexed_processor

    keys_dir = os.path.join(os.path.dirname(SCRIPT_DIR), 'keys')

    # Handle THttpServer as a special case
    if server_type == 'THttpServer':
        if options.ssl:
            cert_file = os.path.join(keys_dir, "server.crt")
            key_file = os.path.join(keys_dir, "server.key")
            server = THttpServer.THttpServer(processor, ('', options.port), pfactory, cert_file=cert_file, key_file=key_file)
        else:
            server = THttpServer.THttpServer(processor, ('', options.port), pfactory)
        server.serve()
        sys.exit(0)

    # set up server transport and transport factory
    abs_key_path = os.path.join(keys_dir, 'server.pem')

    host = None
    if options.ssl:
        from thrift.transport import TSSLSocket
        transport = TSSLSocket.TSSLServerSocket(host, options.port, certfile=abs_key_path)
    else:
        transport = TSocket.TServerSocket(host, options.port, options.domain_socket)

    if options.trans == 'framed':
        tfactory = TTransport.TFramedTransportFactory()
    elif options.trans == '':
        raise AssertionError('Unknown --transport option: %s' % options.trans)
    else:
        # 'buffered' -- and, as before, any other unrecognised value -- uses
        # a buffered transport.
        tfactory = TTransport.TBufferedTransportFactory()

    # if --zlib, then wrap server transport, and use a different transport factory
    if options.zlib:
        transport = TZlibTransport.TZlibTransport(transport)  # wrap with zlib
        tfactory = TZlibTransport.TZlibTransportFactory()

    # do server-specific setup here:
    if server_type == "TNonblockingServer":
        server = TNonblockingServer.TNonblockingServer(processor, transport, inputProtocolFactory=pfactory)
    elif server_type == "TProcessPoolServer":
        from thrift.server import TProcessPoolServer
        server = TProcessPoolServer.TProcessPoolServer(processor, transport, tfactory, pfactory)
        server.setNumWorkers(5)

        def set_alarm():
            # After 4 seconds SIGALRM fires: terminate each worker, then ask
            # the server to stop (best effort).
            def clean_shutdown(signum, frame):
                for worker in server.workers:
                    if options.verbose > 0:
                        logging.info('Terminating worker: %s' % worker)
                    worker.terminate()
                if options.verbose > 0:
                    logging.info('Requesting server to stop()')
                try:
                    server.stop()
                except Exception:
                    pass
            signal.signal(signal.SIGALRM, clean_shutdown)
            signal.alarm(4)
        set_alarm()
    else:
        # look up server class dynamically to instantiate server
        ServerClass = getattr(TServer, server_type)
        server = ServerClass(processor, transport, tfactory, pfactory)
    # enter server main loop
    server.serve()


def exit_gracefully(signum, frame):
    """SIGINT handler: announce the signal, try to shut the server down, exit 0."""
    print("SIGINT received\n")
    # ``server`` is only bound once main() has constructed it, and the server
    # object may not provide shutdown(); skip the call instead of raising
    # from inside the signal handler.
    shutdown = getattr(globals().get('server'), 'shutdown', None)
    if shutdown is not None:
        shutdown()  # doesn't work properly, yet
    sys.exit(0)


if __name__ == '__main__':
    signal.signal(signal.SIGINT, exit_gracefully)

    parser = OptionParser()
    # NOTE(review): --libpydir is parsed but options.libpydir is never read in
    # this file; the library path comes from util.local_libpath() at import.
    parser.add_option('--libpydir', type='string', dest='libpydir',
                      help='include this directory to sys.path for locating library code')
    parser.add_option('--genpydir', type='string', dest='genpydir',
                      default='gen-py',
                      help='include this directory to sys.path for locating generated code')
    parser.add_option("--port", type="int", dest="port",
                      help="port number for server to listen on")
    parser.add_option("--zlib", action="store_true", dest="zlib",
                      help="use zlib wrapper for compressed transport")
    parser.add_option("--ssl", action="store_true", dest="ssl",
                      help="use SSL for encrypted transport")
    parser.add_option('-v', '--verbose', action="store_const",
                      dest="verbose", const=2,
                      help="verbose output")
    parser.add_option('-q', '--quiet', action="store_const",
                      dest="verbose", const=0,
                      help="minimal output")
    parser.add_option('--protocol', dest="proto", type="string",
                      help="protocol to use, one of: accel, accelc, binary, compact, header, json, multi, multia, multiac, multic, multih, multij")
    parser.add_option('--transport', dest="trans", type="string",
                      help="transport to use, one of: buffered, framed, http")
    parser.add_option('--domain-socket', dest="domain_socket", type="string",
                      help="Unix domain socket path")
    parser.add_option('--container-limit', dest='container_limit', type='int', default=None)
    parser.add_option('--string-limit', dest='string_limit', type='int', default=None)
    # The --transport option stores into dest "trans", so that is the key the
    # default must use.
    parser.set_defaults(port=9090, verbose=1, proto='binary', trans='buffered')
    options, args = parser.parse_args()

    # Send log records to the console (logging.basicConfig writes to stderr
    # by default) so that the test-runner can redirect them to log files.
    # options.verbose is 0, 1 or 2, all below logging.DEBUG (10), so this
    # threshold lets every record through; per-call verbosity is gated on
    # options.verbose in the handlers.
    logging.basicConfig(level=options.verbose)

    sys.path.insert(0, os.path.join(SCRIPT_DIR, options.genpydir))

    from ThriftTest import ThriftTest, SecondService
    from ThriftTest.ttypes import Xtruct, Xception, Xception2, Insanity
    from thrift.Thrift import TException
    from thrift.TMultiplexedProcessor import TMultiplexedProcessor
    from thrift.transport import THeaderTransport
    from thrift.transport import TTransport
    from thrift.transport import TSocket
    from thrift.transport import TZlibTransport
    from thrift.protocol import TBinaryProtocol
    from thrift.protocol import TCompactProtocol
    from thrift.protocol import THeaderProtocol
    from thrift.protocol import TJSONProtocol
    from thrift.server import TServer, TNonblockingServer, THttpServer

    sys.exit(main(options, args))