1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10#   http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
20import copy
21import multiprocessing
22import os
23import sys
24from .compat import path_join
25from .util import merge_dict, domain_socket_path
26
27
class TestProgram(object):
    """One executable (a server or a client) taking part in a cross test.

    Stores the program's configuration and assembles the command line used to
    launch it for a given port via :meth:`build_command`.
    """

    def __init__(self, kind, name, protocol, transport, socket, workdir, stop_signal, command, env=None,
                 extra_args=None, extra_args2=None, join_args=False, **kwargs):
        """Store the program configuration.

        Args:
            kind: Role label; ``TestEntry`` passes ``'server'`` or ``'client'``.
            name: Program name, used when composing test names.
            protocol: Value emitted as ``--protocol=<protocol>``.
            transport: Value emitted as ``--transport=<transport>``.
            socket: Socket type; selects extra flags (see ``_socket_args``).
            workdir: Directory against which the executable path is resolved.
            stop_signal: Stored as-is; not used inside this class.
            command: Base command as a sequence; element 0 is the executable.
                The sequence is copied, never modified.
            env: Optional mapping of environment overrides for this program.
            extra_args: Arguments appended at the very end of the command.
            extra_args2: Arguments placed before the generated options.
            join_args: When true, the generated options are passed as one
                space-joined argument instead of separate arguments.
            **kwargs: Accepted and ignored, so a whole config dict can be
                splatted into the constructor.
        """
        self.kind = kind
        self.name = name
        self.protocol = protocol
        self.transport = transport
        self.socket = socket
        self.workdir = workdir
        self.stop_signal = stop_signal
        # Populated by build_command().
        self.command = None
        self._base_command = self._fix_cmd_path(command)
        if env:
            # dict(os.environ) is a detached snapshot. copy.copy(os.environ)
            # would yield another os._Environ sharing the same backing data,
            # so update() would modify the runner's own process environment.
            self.env = dict(os.environ)
            self.env.update(env)
        else:
            # No overrides: hand out the live process environment itself.
            self.env = os.environ
        # None defaults (instead of shared mutable [] defaults); always keep
        # private list copies so later appends/extends cannot alias caller data.
        self._extra_args = list(extra_args) if extra_args else []
        self._extra_args2 = list(extra_args2) if extra_args2 else []
        self._join_args = join_args

    def _fix_cmd_path(self, cmd):
        """Return a copy of ``cmd`` with its first element resolved.

        ``'python'`` is replaced by the running interpreter
        (``sys.executable``). Any other executable is replaced by
        ``path_join(self.workdir, executable)`` when that path exists and is
        left untouched otherwise. The caller's sequence is not modified.
        """
        # If the arg names an existing file under workdir, use that full path.
        def abs_if_exists(arg):
            candidate = path_join(self.workdir, arg)
            return candidate if os.path.exists(candidate) else arg

        fixed = list(cmd)
        if fixed[0] == 'python':
            fixed[0] = sys.executable
        else:
            fixed[0] = abs_if_exists(fixed[0])
        return fixed

    def _socket_args(self, socket, port):
        """Return the extra flags for ``socket`` as a list, or None if none apply.

        ``domain_socket_path(port)`` is only evaluated for the socket types
        that actually need a domain socket path.
        """
        if socket == 'ip-ssl':
            return ['--ssl']
        if socket == 'domain':
            return ['--domain-socket=%s' % domain_socket_path(port)]
        if socket == 'abstract':
            return ['--abstract-namespace', '--domain-socket=%s' % domain_socket_path(port)]
        return None

    def _transport_args(self, transport):
        """Return the extra flags for ``transport`` as a list, or None if none apply."""
        return {
            'zlib': ['--zlib'],
        }.get(transport, None)

    def build_command(self, port):
        """Assemble, store in ``self.command`` and return the full command list.

        Layout: base command, then ``extra_args2`` + ``--protocol`` +
        ``--transport`` + transport flags + socket flags + ``--port`` (as one
        space-joined argument when ``join_args`` is set), then ``extra_args``.

        Args:
            port: Integer port number (formatted with ``%d``).
        """
        cmd = list(self._base_command)
        args = list(self._extra_args2)
        args.append('--protocol=' + self.protocol)
        args.append('--transport=' + self.transport)
        transport_args = self._transport_args(self.transport)
        if transport_args:
            args += transport_args
        socket_args = self._socket_args(self.socket, port)
        if socket_args:
            args += socket_args
        args.append('--port=%d' % port)
        if self._join_args:
            cmd.append('%s' % " ".join(args))
        else:
            cmd.extend(args)
        if self._extra_args:
            cmd.extend(self._extra_args)
        self.command = cmd
        return self.command
94
95
class TestEntry(object):
    """One server/client pairing for a given protocol, transport and socket.

    Builds a ``TestProgram`` for each side from the shared configuration and
    holds the bookkeeping fields for the outcome of the run.
    """

    def __init__(self, testdir, server, client, delay, timeout, **kwargs):
        """Create the entry and its server/client programs.

        Args:
            testdir: Base directory used to resolve each program's workdir.
            server: Server-specific configuration, merged with ``kwargs`` via
                ``merge_dict``.
            client: Client-specific configuration, merged the same way.
            delay: Stored as-is on the entry.
            timeout: Stored as-is on the entry.
            **kwargs: Shared configuration; must contain ``'protocol'``,
                ``'transport'`` and ``'socket'``.

        Raises:
            KeyError: If one of the required shared keys is missing.
        """
        self.testdir = testdir
        self._log = multiprocessing.get_logger()
        self._config = kwargs
        self.protocol = kwargs['protocol']
        self.transport = kwargs['transport']
        self.socket = kwargs['socket']
        srv_dict = self._fix_workdir(merge_dict(self._config, server))
        cli_dict = self._fix_workdir(merge_dict(self._config, client))
        # Each side's 'remote_args' are handed to the *other* side as its
        # extra_args2; both pops run so the key never reaches TestProgram.
        cli_dict['extra_args2'] = srv_dict.pop('remote_args', [])
        srv_dict['extra_args2'] = cli_dict.pop('remote_args', [])
        self.server = TestProgram('server', **srv_dict)
        self.client = TestProgram('client', **cli_dict)
        self.delay = delay
        self.timeout = timeout
        self._name = None
        # results
        self.success = None
        self.as_expected = None
        self.returncode = None
        self.expired = False
        self.retry_count = 0

    def _fix_workdir(self, config):
        """Normalize ``config['workdir']`` in place and return ``config``.

        A missing or empty workdir becomes the real path of ``self.testdir``.
        An absolute workdir is passed through ``os.path.realpath``. A relative
        workdir is resolved against ``self.testdir`` with ``path_join``.
        """
        key = 'workdir'
        path = config.get(key, None)
        if not path:
            # Fall back to testdir itself. Resolve it directly so a relative
            # testdir is not joined onto itself by the relative branch below.
            path = os.path.realpath(self.testdir)
        elif os.path.isabs(path):
            path = os.path.realpath(path)
        else:
            path = os.path.realpath(path_join(self.testdir, path))
        config.update({key: path})
        return config

    @classmethod
    def get_name(cls, server, client, protocol, transport, socket, *args, **kwargs):
        """Return ``'<server>-<client>_<protocol>_<transport>-<socket>'``.

        Extra positional and keyword arguments are accepted and ignored.
        """
        return '%s-%s_%s_%s-%s' % (server, client, protocol, transport, socket)

    @property
    def name(self):
        """Test name built from the server/client program names; cached after first use."""
        if not self._name:
            self._name = self.get_name(
                self.server.name, self.client.name, self.protocol, self.transport, self.socket)
        return self._name

    @property
    def transport_name(self):
        """Return ``'<transport>-<socket>'``."""
        return '%s-%s' % (self.transport, self.socket)
146
147
def test_name(server, client, protocol, transport, socket, **kwargs):
    """Return the test name for a server/client configuration pair.

    ``server`` and ``client`` are mappings with a ``'name'`` entry; the result
    is formatted by ``TestEntry.get_name``. Extra keyword arguments are
    accepted and ignored.
    """
    server_name = server['name']
    client_name = client['name']
    return TestEntry.get_name(server_name, client_name, protocol, transport, socket)
150