1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10#   http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
20import contextlib
21import multiprocessing
22import multiprocessing.managers
23import os
24import platform
25import random
26import socket
27import subprocess
28import sys
29import time
30
31from .compat import str_join
32from .report import ExecReporter, SummaryReporter
33from .test import TestEntry
34from .util import domain_socket_path
35
# Result codes reported for a test.  They are distinct bits because
# RESULT_ERROR may be OR-ed into a client's return code.
RESULT_ERROR = 1 << 6    # 64
RESULT_TIMEOUT = 1 << 7  # 128

# Signal numbers handed to os.killpg / Popen.send_signal.
SIGNONE = 0  # send nothing, only wait for the process to end
# NOTE(review): 15 is SIGTERM's number on POSIX (SIGKILL is 9); the name is
# historical and kept because other code refers to it.
SIGKILL = 15

# globals, assigned by TestDispatcher (directly, or per pool worker)
ports = None
stop = None
44
45
class ExecutionContext(object):
    """Run one cross-test program (server or client) and track how it ends.

    ``start()`` spawns the command with stdout and stderr redirected to
    ``report.out``.  On POSIX the child gets its own session so that
    ``sigwait`` can signal its whole process group; on Windows it is started
    detached in a new process group.
    """

    def __init__(self, cmd, cwd, env, stop_signal, is_server, report):
        self._log = multiprocessing.get_logger()
        self.cmd = cmd
        self.cwd = cwd
        self.env = env
        # Signal used to stop a server; SIGNONE means the server is not
        # signal-aware and is simply sent SIGKILL when the test is over.
        self.stop_signal = stop_signal
        self.is_server = is_server
        self.report = report
        # Set when a plain wait (SIGNONE) timed out.
        self.expired = False
        # Set once the SIGKILL constant has been sent to the process.
        self.killed = False
        self.proc = None

    def _popen_args(self):
        """Return the keyword arguments used for ``subprocess.Popen``."""
        args = {
            'cwd': self.cwd,
            'env': self.env,
            'stdout': self.report.out,
            'stderr': subprocess.STDOUT,
        }
        # make sure child processes don't remain after killing
        if platform.system() == 'Windows':
            DETACHED_PROCESS = 0x00000008
            args.update(creationflags=DETACHED_PROCESS | subprocess.CREATE_NEW_PROCESS_GROUP)
        else:
            args.update(preexec_fn=os.setsid)
        return args

    def start(self):
        """Spawn the process and return a context manager guarding it.

        Leaving the returned context stops the process if it is still
        running and records the outcome on ``report`` (see ``_finish``).
        """
        joined = str_join(' ', self.cmd)
        self._log.debug('COMMAND: %s', joined)
        self._log.debug('WORKDIR: %s', self.cwd)
        self._log.debug('LOGFILE: %s', self.report.logpath)
        self.report.begin()
        self.proc = subprocess.Popen(self.cmd, **self._popen_args())
        self._log.debug('    PID: %d', self.proc.pid)
        # os.getpgid is POSIX-only; calling it unconditionally raised
        # AttributeError on Windows, a platform this class otherwise supports.
        if hasattr(os, 'getpgid'):
            self._log.debug('   PGID: %d', os.getpgid(self.proc.pid))
        return self._scoped()

    @contextlib.contextmanager
    def _scoped(self):
        # try/finally so the process is stopped and reaped even when the
        # ``with`` body raises (e.g. KeyboardInterrupt); previously a server
        # could be left running after the test aborted.
        try:
            yield self
        finally:
            self._finish()

    def _finish(self):
        """Stop the process if it is still running and report the outcome."""
        if self.is_server:
            # the server is supposed to run until we stop it
            if self.returncode is not None:
                self.report.died()
            else:
                if self.stop_signal != SIGNONE:
                    if self.sigwait(self.stop_signal):
                        self.report.end(self.returncode)
                    else:
                        self.report.killed()
                else:
                    self.sigwait(SIGKILL)
        else:
            # the client is supposed to exit normally
            if self.returncode is not None:
                self.report.end(self.returncode)
            else:
                self.sigwait(SIGKILL)
                self.report.killed()
        self._log.debug('[{0}] exited with return code {1}'.format(self.proc.pid, self.returncode))

    def sigwait(self, sig=SIGKILL, timeout=2):
        """Send ``sig`` to the process, then wait up to ``timeout`` seconds.

        With ``SIGNONE`` no signal is sent; the call only waits, and a timeout
        marks the process as ``expired``.  If the process does not end and
        ``sig`` is not SIGKILL, the call escalates to ``sigwait(SIGKILL, 1)``.

        :returns: True if the process ended (``report.end`` has then been
            called), False if it may still be running.
        """
        try:
            if sig != SIGNONE:
                self._log.debug('[{0}] send signal {1}'.format(self.proc.pid, sig))
                if sig == SIGKILL:
                    self.killed = True
                try:
                    if platform.system() != 'Windows':
                        os.killpg(os.getpgid(self.proc.pid), sig)
                    else:
                        self.proc.send_signal(sig)
                except Exception:
                    self._log.info('[{0}] Failed to kill process'.format(self.proc.pid), exc_info=sys.exc_info())
            self._log.debug('[{0}] wait begin, timeout {1} sec(s)'.format(self.proc.pid, timeout))
            self.proc.communicate(timeout=timeout)
            self._log.debug('[{0}] process ended with return code {1}'.format(self.proc.pid, self.returncode))
            self.report.end(self.returncode)
            return True
        except subprocess.TimeoutExpired:
            self._log.info('[{0}] timeout waiting for process to end'.format(self.proc.pid))
            if sig == SIGNONE:
                self.expired = True
            return False if sig == SIGKILL else self.sigwait(SIGKILL, 1)

    def wait(self, timeout):
        """Wait for a client to end on its own; kill it after ``timeout`` s."""
        self.sigwait(SIGNONE, timeout)

    @property
    def returncode(self):
        """The process return code, or None if not started / still running."""
        return self.proc.returncode if self.proc else None
147
148
def exec_context(port, logdir, test, prog, is_server):
    """Create (without starting) an ExecutionContext for ``prog`` on ``port``.

    The reporter is created first; ``prog.build_command(port)`` is then called
    so that ``prog.command`` is built for the allocated port before it is
    captured.
    """
    reporter = ExecReporter(logdir, test, prog)
    prog.build_command(port)
    return ExecutionContext(
        cmd=prog.command,
        cwd=prog.workdir,
        env=prog.env,
        stop_signal=prog.stop_signal,
        is_server=is_server,
        report=reporter,
    )
153
154
def run_test(testdir, logdir, test_dict, max_retry, async_mode=True):
    """Run one server/client test pair, retrying on failure.

    A port is taken from the module-global ``ports`` allocator.  The server is
    started and polled until it accepts connections, then the client is run.
    The client is re-run (up to 12 times) while its report flags a possible
    false positive.  Suspected server bind failures are retried up to three
    times, and failing tests up to ``max_retry`` times.

    :param testdir: base directory handed to ``TestEntry``.
    :param logdir: directory for the per-program log files.
    :param test_dict: keyword arguments used to build the ``TestEntry``.
    :param max_retry: number of extra attempts allowed for a failing test.
    :param async_mode: when False, exceptions propagate to the caller instead
        of being converted into ``RESULT_ERROR``.
    :returns: ``(retry_count, result)``.  ``result`` is None when skipped
        because ``stop`` is set.  Otherwise it is the client's return code or
        a ``RESULT_*`` value, possibly OR-ed with ``RESULT_ERROR`` when the
        server misbehaved; 0 means success.
    """
    logger = multiprocessing.get_logger()

    def tcp_port_open(port):
        """Return True if localhost accepts a TCP connection on ``port``."""
        for family, host in ((socket.AF_INET, '127.0.0.1'), (socket.AF_INET6, '::1')):
            # Create sockets every attempt because refused sockets cannot be
            # reused on some systems.
            try:
                sock = socket.socket(family=family)
            except OSError:
                # e.g. IPv6 disabled on this host; try the other family.
                continue
            with sock:
                if sock.connect_ex((host, port)) == 0:
                    return True
        return False

    def ensure_socket_open(sv, port, test):
        """Poll until the server is reachable, exits, or ``test.delay`` passes."""
        slept = 0.1
        time.sleep(slept)
        sleep_step = 0.1
        while True:
            if slept > test.delay:
                logger.warning('[{0}] slept for {1} seconds but server is not open'.format(sv.proc.pid, slept))
                return False
            if test.socket == 'abstract':
                # Nothing on the filesystem to poll for.
                return True
            if test.socket == 'domain':
                ready = os.path.exists(domain_socket_path(port))
            else:
                ready = tcp_port_open(port)
            if ready:
                logger.debug('[{0}] server ready - waited for {1} seconds'.format(sv.proc.pid, slept))
                return True
            if sv.proc.poll() is not None:
                logger.warning('[{0}] server process is exited'.format(sv.proc.pid))
                return False
            if test.socket == 'domain':
                logger.debug('[{0}] domain(unix) socket not available yet. slept for {1} seconds so far'.format(sv.proc.pid, slept))
            else:
                logger.debug('[{0}] socket not available yet. slept for {1} seconds so far'.format(sv.proc.pid, slept))
            # Previously control fell through to "server ready" here, so the
            # loop never actually polled more than once.
            time.sleep(sleep_step)
            slept += sleep_step

    # Bound before ``try`` so the handlers can use them even if building the
    # TestEntry itself fails.
    retry_count = 0
    test = None
    try:
        max_bind_retry = 3
        bind_retry_count = 0
        test = TestEntry(testdir, **test_dict)
        while True:
            if stop.is_set():
                logger.debug('Skipping because shutting down')
                return (retry_count, None)
            logger.debug('Start')
            with PortAllocator.alloc_port_scoped(ports, test.socket) as port:
                logger.debug('Start with port %d', port)
                sv = exec_context(port, logdir, test, test.server, True)
                cl = exec_context(port, logdir, test, test.client, False)

                logger.debug('Starting server')
                with sv.start():
                    port_ok = ensure_socket_open(sv, port, test)
                    if port_ok:
                        connect_retry_count = 0
                        max_connect_retry = 12
                        connect_retry_wait = 0.25
                        while True:
                            if sv.proc.poll() is not None:
                                logger.info('not starting client because server process is absent')
                                break
                            logger.debug('Starting client')
                            # The returned context manager is not entered;
                            # cl.wait() below kills a client that overruns.
                            cl.start()
                            logger.debug('Waiting client (up to %d secs)', test.timeout)
                            cl.wait(test.timeout)
                            if not cl.report.maybe_false_positive() or connect_retry_count >= max_connect_retry:
                                if 0 < connect_retry_count < max_connect_retry:
                                    logger.info('[%s]: Connected after %d retry (%.2f sec each)', test.server.name, connect_retry_count, connect_retry_wait)
                                # Wait for 50ms to see if server does not die at the end.
                                time.sleep(0.05)
                                break
                            logger.debug('Server may not be ready, waiting %.2f second...', connect_retry_wait)
                            time.sleep(connect_retry_wait)
                            connect_retry_count += 1

            if sv.report.maybe_false_positive() and bind_retry_count < max_bind_retry:
                logger.warning('[%s]: Detected socket bind failure, retrying...', test.server.name)
                bind_retry_count += 1
                continue

            if cl.expired:
                result = RESULT_TIMEOUT
            elif cl.proc is not None and cl.proc.poll() is not None:
                result = cl.returncode
            else:
                # client was never started or has not exited
                result = RESULT_ERROR

            # For servers that handle a controlled shutdown by signal
            # if they are killed, or return an error code, that is a
            # problem.  For servers that are not signal-aware, we simply
            # kill them off; if we didn't kill them off, something else
            # happened (crashed?)
            if test.server.stop_signal != 0:
                # for bash scripts, 128+N is the exit code for signal N, since we are sending
                # DEFAULT_SIGNAL=1, 128 + 1 is the expected err code
                # http://www.gnu.org/software/bash/manual/html_node/Exit-Status.html
                allowed_return_code = {-1, 0, 128 + 1}
                if sv.killed or sv.returncode not in allowed_return_code:
                    result |= RESULT_ERROR
            else:
                if not sv.killed:
                    result |= RESULT_ERROR

            if result == 0 or retry_count >= max_retry:
                return (retry_count, result)
            logger.info('[%s-%s]: test failed, retrying...', test.server.name, test.client.name)
            retry_count += 1
    except Exception:
        if not async_mode:
            raise
        logger.warning('Error executing [%s]', test.name if test is not None else test_dict, exc_info=True)
        return (retry_count, RESULT_ERROR)
    except BaseException:
        # KeyboardInterrupt and friends: the previous second
        # ``except Exception`` clause was unreachable, so an interrupt never
        # set ``stop`` for the other workers.
        logger.info('Interrupted execution', exc_info=True)
        if not async_mode:
            raise
        stop.set()
        return (retry_count, RESULT_ERROR)
272
273
class PortAllocator(object):
    """Hand out ports not currently used by another running test.

    In concurrent mode an instance is registered with a
    ``multiprocessing.managers.BaseManager`` and workers use it through a
    proxy.  All bookkeeping is done while holding ``self._lock``.
    """

    def __init__(self):
        self._log = multiprocessing.get_logger()
        self._lock = multiprocessing.Lock()
        self._ports = set()      # TCP ports currently handed out
        self._dom_ports = set()  # ids used for domain/abstract socket names
        self._last_alloc = 0     # time.time() of the last TCP allocation

    def _get_tcp_port(self):
        """Return an OS-chosen free TCP port that is not already handed out."""
        while True:
            # ``with`` closes the probe socket even if bind() fails (it
            # previously leaked in that case).
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                sock.bind(('', 0))
                port = sock.getsockname()[1]
                with self._lock:
                    if port not in self._ports:
                        self._ports.add(port)
                        self._last_alloc = time.time()
                        return port

    def _get_domain_port(self):
        """Return a random id in [1024, 65536] that is not already in use."""
        while True:
            port = random.randint(1024, 65536)
            with self._lock:
                if port not in self._dom_ports:
                    self._dom_ports.add(port)
                    return port

    def alloc_port(self, socket_type):
        """Allocate an id for 'domain'/'abstract' sockets, else a TCP port."""
        if socket_type in ('domain', 'abstract'):
            return self._get_domain_port()
        else:
            return self._get_tcp_port()

    # static method for inter-process invocation: ``allocator`` may be a
    # manager proxy, so only its public methods are called.
    @staticmethod
    @contextlib.contextmanager
    def alloc_port_scoped(allocator, socket_type):
        """Allocate a port for the duration of a ``with`` block.

        The port is released even when the block raises; previously an
        exception leaked the port (and, for domain sockets, the socket file).
        """
        port = allocator.alloc_port(socket_type)
        try:
            yield port
        finally:
            allocator.free_port(socket_type, port)

    def free_port(self, socket_type, port):
        """Release ``port``; for domain sockets also delete the socket file.

        Failures (unknown port, file removal error) are logged, not raised,
        so cleanup never masks an exception already in flight.
        """
        self._log.debug('free_port')
        with self._lock:
            try:
                if socket_type == 'domain':
                    self._dom_ports.remove(port)
                    path = domain_socket_path(port)
                    if os.path.exists(path):
                        os.remove(path)
                elif socket_type == 'abstract':
                    self._dom_ports.remove(port)
                else:
                    self._ports.remove(port)
            except (KeyError, OSError):
                # set.remove raises KeyError, which the old IOError-only
                # handler let escape.
                self._log.info('Error while freeing port', exc_info=sys.exc_info())
340
341
class NonAsyncResult(object):
    """Already-computed stand-in for ``multiprocessing.pool.AsyncResult``.

    Used in synchronous mode: every query answers immediately and any
    ``timeout`` argument is ignored.
    """

    def __init__(self, value):
        self._value = value

    def get(self, timeout=None):
        # ``timeout`` exists only for API parity with AsyncResult.
        value = self._value
        return value

    def wait(self, timeout=None):
        # Nothing to wait for.
        return None

    def ready(self):
        return True

    def successful(self):
        # NOTE(review): this compares the stored value with 0, but
        # _dispatch_sync stores run_test's (retry_count, result) tuple, which
        # never equals 0 -- confirm whether callers rely on this.
        is_zero = self._value == 0
        return is_zero
357
358
class TestDispatcher(object):
    """Run tests inline (concurrency 1) or on a ``multiprocessing`` pool.

    Synchronous mode sets the module globals ``stop`` and ``ports`` in this
    process.  Asynchronous mode starts a ``BaseManager`` hosting a shared
    ``PortAllocator``; each pool worker binds its own globals in
    ``_pool_init``.
    """

    def __init__(self, testdir, basedir, logdir_rel, concurrency):
        global stop, ports
        self._log = multiprocessing.get_logger()
        self.testdir = testdir
        use_pool = concurrency > 1
        self._report = SummaryReporter(basedir, logdir_rel, use_pool)
        self.logdir = self._report.testdir
        # seems needed for python 2.x to handle keyboard interrupt
        self._stop = multiprocessing.Event()
        self._async = use_pool
        if use_pool:
            manager = multiprocessing.managers.BaseManager()
            self._m = manager
            manager.register('ports', PortAllocator)
            manager.start()
            self._pool = multiprocessing.Pool(concurrency, self._pool_init, (manager.address,))
        else:
            self._pool = None
            stop = self._stop
            ports = PortAllocator()
        self._log.debug(
            'TestDispatcher started with %d concurrent jobs' % concurrency)

    def _pool_init(self, address):
        """Pool initializer: bind this worker's globals to shared state."""
        global stop, m, ports
        stop = self._stop
        m = multiprocessing.managers.BaseManager(address)
        m.connect()
        ports = m.ports()

    def _dispatch_sync(self, test, cont, max_retry):
        outcome = run_test(self.testdir, self.logdir, test, max_retry, async_mode=False)
        cont(outcome)
        return NonAsyncResult(outcome)

    def _dispatch_async(self, test, cont, max_retry):
        self._log.debug('_dispatch_async')
        job_args = (self.testdir, self.logdir, test, max_retry)
        return self._pool.apply_async(func=run_test, args=job_args, callback=cont)

    def dispatch(self, test, max_retry):
        """Schedule ``test``; return an AsyncResult-like handle."""
        index = self._report.add_test(test)

        def on_result(result):
            # Results arriving after shutdown was requested are dropped.
            if self._stop.is_set():
                return
            if result and len(result) == 2:
                retry_count, returncode = result
            else:
                retry_count, returncode = 0, RESULT_ERROR
            self._log.debug('freeing port')
            self._log.debug('adding result')
            timed_out = returncode == RESULT_TIMEOUT
            self._report.add_result(index, returncode, timed_out, retry_count)
            self._log.debug('finish continuation')

        runner = self._dispatch_async if self._async else self._dispatch_sync
        return runner(test, on_result, max_retry)

    def wait(self):
        """Wait for every dispatched test, then finish the summary report."""
        if self._async:
            pool = self._pool
            pool.close()
            pool.join()
            self._m.shutdown()
        return self._report.end()

    def terminate(self):
        """Signal shutdown and, in pool mode, kill the workers immediately."""
        self._stop.set()
        if not self._async:
            return
        pool = self._pool
        pool.terminate()
        pool.join()
        self._m.shutdown()
430