1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20package common
21
import (
	"compress/zlib"
	"crypto/tls"
	"flag"
	"fmt"
	"net"

	"github.com/apache/thrift/lib/go/thrift"
	"github.com/apache/thrift/test/go/src/gen/thrifttest"
)
31
// debugServerProtocol mirrors the -debug_server_protocol command-line flag.
// When it is true, GetServerParams wraps the selected protocol factory with
// thrift.NewTDebugProtocolFactoryWithLogger so server traffic is traced.
var debugServerProtocol bool

// init registers the -debug_server_protocol flag (default false) on the
// default flag set.
func init() {
	const (
		debugFlagName  = "debug_server_protocol"
		debugFlagUsage = "turn server protocol trace on"
	)
	flag.BoolVar(&debugServerProtocol, debugFlagName, false, debugFlagUsage)
}
39
40func GetServerParams(
41	host string,
42	port int64,
43	domain_socket string,
44	transport string,
45	protocol string,
46	ssl bool,
47	certPath string,
48	handler thrifttest.ThriftTest,
49) (thrift.TProcessor, thrift.TServerTransport, thrift.TTransportFactory, thrift.TProtocolFactory, error) {
50
51	var err error
52	hostPort := fmt.Sprintf("%s:%d", host, port)
53	var cfg *thrift.TConfiguration = nil
54
55	var protocolFactory thrift.TProtocolFactory
56	switch protocol {
57	case "compact":
58		protocolFactory = thrift.NewTCompactProtocolFactoryConf(cfg)
59	case "simplejson":
60		protocolFactory = thrift.NewTSimpleJSONProtocolFactoryConf(cfg)
61	case "json":
62		protocolFactory = thrift.NewTJSONProtocolFactory()
63	case "binary":
64		protocolFactory = thrift.NewTBinaryProtocolFactoryConf(nil)
65	case "header":
66		protocolFactory = thrift.NewTHeaderProtocolFactoryConf(nil)
67	default:
68		return nil, nil, nil, nil, fmt.Errorf("invalid protocol specified %s", protocol)
69	}
70	if debugServerProtocol {
71		protocolFactory = thrift.NewTDebugProtocolFactoryWithLogger(protocolFactory, "server:", thrift.StdLogger(nil))
72	}
73
74	var serverTransport thrift.TServerTransport
75	if ssl {
76		cfg := new(tls.Config)
77		if cert, err := tls.LoadX509KeyPair(certPath+"/server.crt", certPath+"/server.key"); err != nil {
78			return nil, nil, nil, nil, err
79		} else {
80			cfg.Certificates = append(cfg.Certificates, cert)
81		}
82		serverTransport, err = thrift.NewTSSLServerSocket(hostPort, cfg)
83	} else {
84		if domain_socket != "" {
85			serverTransport, err = thrift.NewTServerSocket(domain_socket)
86		} else {
87			serverTransport, err = thrift.NewTServerSocket(hostPort)
88		}
89	}
90	if err != nil {
91		return nil, nil, nil, nil, err
92	}
93
94	var transportFactory thrift.TTransportFactory
95
96	switch transport {
97	case "http":
98		// there is no such factory, and we don't need any
99		transportFactory = nil
100	case "framed":
101		transportFactory = thrift.NewTTransportFactory()
102		transportFactory = thrift.NewTFramedTransportFactoryConf(transportFactory, nil)
103	case "buffered":
104		transportFactory = thrift.NewTBufferedTransportFactory(8192)
105	case "zlib":
106		transportFactory = thrift.NewTZlibTransportFactory(zlib.BestCompression)
107	case "":
108		transportFactory = thrift.NewTTransportFactory()
109	default:
110		return nil, nil, nil, nil, fmt.Errorf("invalid transport specified %s", transport)
111	}
112	processor := thrifttest.NewThriftTestProcessor(handler)
113
114	return processor, serverTransport, transportFactory, protocolFactory, nil
115}
116