1/* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20package common 21 22import ( 23 "compress/zlib" 24 "crypto/tls" 25 "flag" 26 "fmt" 27 28 "github.com/apache/thrift/lib/go/thrift" 29 "github.com/apache/thrift/test/go/src/gen/thrifttest" 30) 31 32var ( 33 debugServerProtocol bool 34) 35 36func init() { 37 flag.BoolVar(&debugServerProtocol, "debug_server_protocol", false, "turn server protocol trace on") 38} 39 40func GetServerParams( 41 host string, 42 port int64, 43 domain_socket string, 44 transport string, 45 protocol string, 46 ssl bool, 47 certPath string, 48 handler thrifttest.ThriftTest, 49) (thrift.TProcessor, thrift.TServerTransport, thrift.TTransportFactory, thrift.TProtocolFactory, error) { 50 51 var err error 52 hostPort := fmt.Sprintf("%s:%d", host, port) 53 var cfg *thrift.TConfiguration = nil 54 55 var protocolFactory thrift.TProtocolFactory 56 switch protocol { 57 case "compact": 58 protocolFactory = thrift.NewTCompactProtocolFactoryConf(cfg) 59 case "simplejson": 60 protocolFactory = thrift.NewTSimpleJSONProtocolFactoryConf(cfg) 61 case "json": 62 protocolFactory = thrift.NewTJSONProtocolFactory() 63 case "binary": 64 protocolFactory = thrift.NewTBinaryProtocolFactoryConf(nil) 65 case "header": 66 protocolFactory = thrift.NewTHeaderProtocolFactoryConf(nil) 67 default: 68 return nil, nil, nil, nil, fmt.Errorf("invalid protocol specified %s", protocol) 69 } 70 if debugServerProtocol { 71 protocolFactory = thrift.NewTDebugProtocolFactoryWithLogger(protocolFactory, "server:", thrift.StdLogger(nil)) 72 } 73 74 var serverTransport thrift.TServerTransport 75 if ssl { 76 cfg := new(tls.Config) 77 if cert, err := tls.LoadX509KeyPair(certPath+"/server.crt", certPath+"/server.key"); err != nil { 78 return nil, nil, nil, nil, err 79 } else { 80 cfg.Certificates = append(cfg.Certificates, cert) 81 } 82 serverTransport, err = thrift.NewTSSLServerSocket(hostPort, cfg) 83 } else { 84 if domain_socket != "" { 85 serverTransport, err = thrift.NewTServerSocket(domain_socket) 86 } else { 87 serverTransport, err = thrift.NewTServerSocket(hostPort) 88 } 89 } 90 if err != nil { 91 return nil, nil, nil, nil, err 92 } 93 94 var transportFactory thrift.TTransportFactory 95 96 switch transport { 97 case "http": 98 // there is no such factory, and we don't need any 99 transportFactory = nil 100 case "framed": 101 transportFactory = thrift.NewTTransportFactory() 102 transportFactory = thrift.NewTFramedTransportFactoryConf(transportFactory, nil) 103 case "buffered": 104 transportFactory = thrift.NewTBufferedTransportFactory(8192) 105 case "zlib": 106 transportFactory = thrift.NewTZlibTransportFactory(zlib.BestCompression) 107 case "": 108 transportFactory = thrift.NewTTransportFactory() 109 default: 110 return nil, nil, nil, nil, fmt.Errorf("invalid transport specified %s", transport) 111 } 112 processor := thrifttest.NewThriftTestProcessor(handler) 113 114 return processor, serverTransport, transportFactory, protocolFactory, nil 115} 116