From f20f776dc336f5556c949fd7dc73e5a7dd10c8cf Mon Sep 17 00:00:00 2001
From: Thomas Grainger
Date: Fri, 27 May 2022 15:21:35 +0100
Subject: [PATCH] Use asyncio.run to run gen_cluster, gen_test and cluster
 (#6231)

Closes https://github.com/dask/distributed/issues/6164
---
 distributed/utils_test.py | 554 +++++++++++++++++++++-----------------
 1 file changed, 301 insertions(+), 253 deletions(-)

diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index 46e795fd3c..edd74bef17 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -19,6 +19,7 @@
 import sys
 import tempfile
 import threading
+import warnings
 import weakref
 from collections import defaultdict
 from collections.abc import Callable
@@ -30,7 +31,6 @@
 import pytest
 import yaml
 from tlz import assoc, memoize, merge
-from tornado import gen
 from tornado.ioloop import IOLoop
 
 import dask
@@ -136,7 +136,7 @@ async def cleanup_global_workers():
 
 
 @pytest.fixture
-def loop():
+def loop(cleanup):
     with check_instances():
         with pristine_loop() as loop:
             # Monkey-patch IOLoop.start to wait for loop stop
@@ -169,17 +169,35 @@ def start():
 
 
 @pytest.fixture
-def loop_in_thread():
-    with pristine_loop() as loop:
-        thread = threading.Thread(target=loop.start, name="test IOLoop")
-        thread.daemon = True
-        thread.start()
-        loop_started = threading.Event()
-        loop.add_callback(loop_started.set)
-        loop_started.wait()
-        yield loop
-        loop.add_callback(loop.stop)
-        thread.join(timeout=5)
+def loop_in_thread(cleanup):
+    loop_started = concurrent.futures.Future()
+    with concurrent.futures.ThreadPoolExecutor(
+        1, thread_name_prefix="test IOLoop"
+    ) as tpe:
+
+        async def run():
+            io_loop = IOLoop.current()
+            stop_event = asyncio.Event()
+            loop_started.set_result((io_loop, stop_event))
+            await stop_event.wait()
+
+        # run asyncio.run in a thread and collect exceptions from *either*
+        # the loop failing to start, or failing to close
+        ran = tpe.submit(_run_and_close_tornado, run)
+        for f in concurrent.futures.as_completed((loop_started, ran)):
+            if f is loop_started:
+                io_loop, stop_event = loop_started.result()
+                try:
+                    yield io_loop
+                finally:
+                    io_loop.add_callback(stop_event.set)
+
+            elif f is ran:
+                # if this is the first iteration the loop failed to start
+                # if it's the second iteration the loop has finished or
+                # the loop failed to close and we need to raise the exception
+                ran.result()
+                return
 
 
 @pytest.fixture
@@ -442,29 +460,37 @@ async def background_read():
     return msg
 
 
+def _run_and_close_tornado(async_fn, /, *args, **kwargs):
+    tornado_loop = None
+
+    async def inner_fn():
+        nonlocal tornado_loop
+        tornado_loop = IOLoop.current()
+        return await async_fn(*args, **kwargs)
+
+    try:
+        return asyncio.run(inner_fn())
+    finally:
+        tornado_loop.close(all_fds=True)
+
+
 def run_scheduler(q, nputs, config, port=0, **kwargs):
     with dask.config.set(config):
-        # On Python 2.7 and Unix, fork() is used to spawn child processes,
-        # so avoid inheriting the parent's IO loop.
-        with pristine_loop() as loop:
-
-            async def _():
-                try:
-                    scheduler = await Scheduler(
-                        validate=True, host="127.0.0.1", port=port, **kwargs
-                    )
-                except Exception as exc:
-                    for i in range(nputs):
-                        q.put(exc)
-                else:
-                    for i in range(nputs):
-                        q.put(scheduler.address)
-                    await scheduler.finished()
 
+        async def _():
             try:
-                loop.run_sync(_)
-            finally:
-                loop.close(all_fds=True)
+                scheduler = await Scheduler(
+                    validate=True, host="127.0.0.1", port=port, **kwargs
+                )
+            except Exception as exc:
+                for i in range(nputs):
+                    q.put(exc)
+            else:
+                for i in range(nputs):
+                    q.put(scheduler.address)
+                await scheduler.finished()
+
+        _run_and_close_tornado(_)
 
 
 def run_worker(q, scheduler_q, config, **kwargs):
@@ -473,37 +499,12 @@ def run_worker(q, scheduler_q, config, **kwargs):
 
         reset_logger_locks()
         with log_errors():
-            with pristine_loop() as loop:
-                scheduler_addr = scheduler_q.get()
-
-                async def _():
-                    pid = os.getpid()
-                    try:
-                        worker = await Worker(scheduler_addr, validate=True, **kwargs)
-                    except Exception as exc:
-                        q.put((pid, exc))
-                    else:
-                        q.put((pid, worker.address))
-                        await worker.finished()
-
-                # Scheduler might've failed
-                if isinstance(scheduler_addr, str):
-                    try:
-                        loop.run_sync(_)
-                    finally:
-                        loop.close(all_fds=True)
-
-
-@log_errors
-def run_nanny(q, scheduler_q, config, **kwargs):
-    with dask.config.set(config):
-        with pristine_loop() as loop:
             scheduler_addr = scheduler_q.get()
 
             async def _():
                 pid = os.getpid()
                 try:
-                    worker = await Nanny(scheduler_addr, validate=True, **kwargs)
+                    worker = await Worker(scheduler_addr, validate=True, **kwargs)
                 except Exception as exc:
                     q.put((pid, exc))
                 else:
@@ -512,14 +513,36 @@ async def _():
 
             # Scheduler might've failed
             if isinstance(scheduler_addr, str):
-                try:
-                    loop.run_sync(_)
-                finally:
-                    loop.close(all_fds=True)
+                _run_and_close_tornado(_)
+
+
+@log_errors
+def run_nanny(q, scheduler_q, config, **kwargs):
+    with dask.config.set(config):
+        scheduler_addr = scheduler_q.get()
+
+        async def _():
+            pid = os.getpid()
+            try:
+                worker = await Nanny(scheduler_addr, validate=True, **kwargs)
+            except Exception as exc:
+                q.put((pid, exc))
+            else:
+                q.put((pid, worker.address))
+                await worker.finished()
+
+        # Scheduler might've failed
+        if isinstance(scheduler_addr, str):
+            _run_and_close_tornado(_)
 
 
 @contextmanager
 def check_active_rpc(loop, active_rpc_timeout=1):
+    warnings.warn(
+        "check_active_rpc is deprecated - use gen_test()",
+        DeprecationWarning,
+        stacklevel=2,
+    )
     active_before = set(rpc.active)
     yield
     # Some streams can take a bit of time to notice their peer
@@ -544,6 +567,29 @@ async def wait():
     loop.run_sync(wait)
 
 
+@contextlib.asynccontextmanager
+async def _acheck_active_rpc(active_rpc_timeout=1):
+    active_before = set(rpc.active)
+    yield
+    # Some streams can take a bit of time to notice their peer
+    # has closed, and keep a coroutine (*) waiting for a CommClosedError
+    # before calling close_rpc() after a CommClosedError.
+    # This would happen especially if a non-localhost address is used,
+    # as Nanny does.
+    # (*) (example: gather_from_workers())
+
+    def fail():
+        pytest.fail(
+            "some RPCs left active by test: %s" % (set(rpc.active) - active_before)
+        )
+
+    await async_wait_for(
+        lambda: len(set(rpc.active) - active_before) == 0,
+        timeout=active_rpc_timeout,
+        fail_func=fail,
+    )
+
+
 @pytest.fixture
 def cluster_fixture(loop):
     with cluster() as (scheduler, workers):
@@ -656,7 +702,7 @@ def cluster(
     ws = weakref.WeakSet()
     enable_proctitle_on_children()
 
-    with clean(timeout=active_rpc_timeout, threads=False) as loop:
+    with check_process_leak(check=True), check_instances(), _reconfigure():
         if nanny:
             _run_worker = run_nanny
         else:
@@ -736,7 +782,7 @@ async def wait_for_workers():
                 if time() - start > 5:
                     raise Exception("Timeout on cluster creation")
 
-            loop.run_sync(wait_for_workers)
+            _run_and_close_tornado(wait_for_workers)
 
             # avoid sending processes down to function
             yield {"address": saddr}, [
@@ -744,26 +790,25 @@ async def wait_for_workers():
                 {"address": w["address"], "proc": weakref.ref(w["proc"])}
                 for w in workers_by_pid.values()
             ]
         finally:
-            logger.debug("Closing out test cluster")
-            alive_workers = [
-                w["address"]
-                for w in workers_by_pid.values()
-                if w["proc"].is_alive()
-            ]
-            loop.run_sync(
-                lambda: disconnect_all(
+
+            async def close():
+                logger.debug("Closing out test cluster")
+                alive_workers = [
+                    w["address"]
+                    for w in workers_by_pid.values()
+                    if w["proc"].is_alive()
+                ]
+                await disconnect_all(
                     alive_workers,
                     timeout=disconnect_timeout,
                     rpc_kwargs=rpc_kwargs,
                 )
-            )
-            if scheduler.is_alive():
-                loop.run_sync(
-                    lambda: disconnect(
+                if scheduler.is_alive():
+                    await disconnect(
                         saddr, timeout=disconnect_timeout, rpc_kwargs=rpc_kwargs
                     )
-                )
+
+            _run_and_close_tornado(close)
         try:
             client = default_client()
         except ValueError:
@@ -815,17 +860,20 @@ async def test_foo():
     if is_debugging():
         timeout = 3600
 
+    async def async_fn_outer(async_fn, /, *args, **kwargs):
+        async with _acheck_active_rpc():
+            return await asyncio.wait_for(
+                asyncio.create_task(async_fn(*args, **kwargs)), timeout
+            )
+
    def _(func):
        @functools.wraps(func)
+        @clean(**clean_kwargs)
        def test_func(*args, **kwargs):
-            with clean(**clean_kwargs) as loop:
-                injected_func = functools.partial(func, *args, **kwargs)
-                if iscoroutinefunction(func):
-                    cor = injected_func
-                else:
-                    cor = gen.coroutine(injected_func)
+            if not iscoroutinefunction(func):
+                raise RuntimeError("gen_test only works for coroutine functions.")
 
-                loop.run_sync(cor, timeout=timeout)
+            return _run_and_close_tornado(async_fn_outer, func, *args, **kwargs)
 
         # Patch the signature so pytest can inject fixtures
         test_func.__signature__ = inspect.signature(func)
@@ -837,7 +885,7 @@ def test_func(*args, **kwargs):
 async def start_cluster(
     nthreads: list[tuple[str, int] | tuple[str, int, dict]],
     scheduler_addr: str,
-    loop: IOLoop,
+    loop: IOLoop | None = None,
     security: Security | dict[str, Any] | None = None,
     Worker: type[ServerNode] = Worker,
     scheduler_kwargs: dict[str, Any] = {},
@@ -1014,161 +1062,156 @@ def _(func):
             raise RuntimeError("gen_cluster only works for coroutine functions.")
 
         @functools.wraps(func)
+        @clean(**clean_kwargs)
         def test_func(*outer_args, **kwargs):
-            result = None
-            with clean(timeout=active_rpc_timeout, **clean_kwargs) as loop:
-
-                async def coro():
-                    with tempfile.TemporaryDirectory() as tmpdir:
-                        config2 = merge({"temporary-directory": tmpdir}, config)
-                        with dask.config.set(config2):
-                            workers = []
-                            s = False
-
-                            for _ in range(60):
-                                try:
-                                    s, ws = await start_cluster(
-                                        nthreads,
-                                        scheduler,
-                                        loop,
-                                        security=security,
-                                        Worker=Worker,
-                                        scheduler_kwargs=scheduler_kwargs,
-                                        worker_kwargs=worker_kwargs,
-                                    )
-                                except Exception as e:
-                                    logger.error(
-                                        "Failed to start gen_cluster: "
-                                        f"{e.__class__.__name__}: {e}; retrying",
-                                        exc_info=True,
-                                    )
-                                    await asyncio.sleep(1)
-                                else:
-                                    workers[:] = ws
-                                    args = [s] + workers
-                                    break
-                            if s is False:
-                                raise Exception("Could not start cluster")
-                            if client:
-                                c = await Client(
-                                    s.address,
-                                    loop=loop,
+            async def async_fn():
+                result = None
+                with tempfile.TemporaryDirectory() as tmpdir:
+                    config2 = merge({"temporary-directory": tmpdir}, config)
+                    with dask.config.set(config2):
+                        workers = []
+                        s = False
+
+                        for _ in range(60):
+                            try:
+                                s, ws = await start_cluster(
+                                    nthreads,
+                                    scheduler,
                                     security=security,
-                                    asynchronous=True,
-                                    **client_kwargs,
+                                    Worker=Worker,
+                                    scheduler_kwargs=scheduler_kwargs,
+                                    worker_kwargs=worker_kwargs,
                                 )
-                                args = [c] + args
-
-                            try:
-                                coro = func(*args, *outer_args, **kwargs)
-                                task = asyncio.create_task(coro)
-                                coro2 = asyncio.wait_for(asyncio.shield(task), timeout)
-                                result = await coro2
-                                validate_state(s, *workers)
-
-                            except asyncio.TimeoutError:
-                                assert task
-                                buffer = io.StringIO()
-                                # This stack indicates where the coro/test is suspended
-                                task.print_stack(file=buffer)
-
-                                if cluster_dump_directory:
-                                    await dump_cluster_state(
-                                        s,
-                                        ws,
-                                        output_dir=cluster_dump_directory,
-                                        func_name=func.__name__,
-                                    )
-
-                                task.cancel()
-                                while not task.cancelled():
-                                    await asyncio.sleep(0.01)
-
-                                # Hopefully, the hang has been caused by inconsistent
-                                # state, which should be much more meaningful than the
-                                # timeout
-                                validate_state(s, *workers)
-
-                                # Remove as much of the traceback as possible; it's
-                                # uninteresting boilerplate from utils_test and asyncio
-                                # and not from the code being tested.
-                                raise asyncio.TimeoutError(
-                                    f"Test timeout after {timeout}s.\n"
-                                    "========== Test stack trace starts here ==========\n"
-                                    f"{buffer.getvalue()}"
-                                ) from None
-
-                            except pytest.xfail.Exception:
-                                raise
-
-                            except Exception:
-                                if cluster_dump_directory and not has_pytestmark(
-                                    test_func, "xfail"
-                                ):
-                                    await dump_cluster_state(
-                                        s,
-                                        ws,
-                                        output_dir=cluster_dump_directory,
-                                        func_name=func.__name__,
-                                    )
-                                raise
-
-                            finally:
-                                if client and c.status not in ("closing", "closed"):
-                                    await c._close(fast=s.status == Status.closed)
-                                await end_cluster(s, workers)
-                                await asyncio.wait_for(cleanup_global_workers(), 1)
-
-                                try:
-                                    c = await default_client()
-                                except ValueError:
-                                    pass
+                            except Exception as e:
+                                logger.error(
+                                    "Failed to start gen_cluster: "
+                                    f"{e.__class__.__name__}: {e}; retrying",
+                                    exc_info=True,
+                                )
+                                await asyncio.sleep(1)
                             else:
-                                    await c._close(fast=True)
-
-                                def get_unclosed():
-                                    return [
-                                        c for c in Comm._instances if not c.closed()
-                                    ] + [
-                                        c
-                                        for c in _global_clients.values()
-                                        if c.status != "closed"
-                                    ]
+                                workers[:] = ws
+                                args = [s] + workers
+                                break
+                        if s is False:
+                            raise Exception("Could not start cluster")
+                        if client:
+                            c = await Client(
+                                s.address,
+                                security=security,
+                                asynchronous=True,
+                                **client_kwargs,
+                            )
+                            args = [c] + args
+
+                        try:
+                            coro = func(*args, *outer_args, **kwargs)
+                            task = asyncio.create_task(coro)
+                            coro2 = asyncio.wait_for(asyncio.shield(task), timeout)
+                            result = await coro2
+                            validate_state(s, *workers)
+
+                        except asyncio.TimeoutError:
+                            assert task
+                            buffer = io.StringIO()
+                            # This stack indicates where the coro/test is suspended
+                            task.print_stack(file=buffer)
+
+                            if cluster_dump_directory:
+                                await dump_cluster_state(
+                                    s,
+                                    ws,
+                                    output_dir=cluster_dump_directory,
+                                    func_name=func.__name__,
+                                )
-
-                                try:
-                                    start = time()
-                                    while time() < start + 60:
-                                        gc.collect()
-                                        if not get_unclosed():
-                                            break
-                                        await asyncio.sleep(0.05)
+
+                            task.cancel()
+                            while not task.cancelled():
+                                await asyncio.sleep(0.01)
+
+                            # Hopefully, the hang has been caused by inconsistent
+                            # state, which should be much more meaningful than the
+                            # timeout
+                            validate_state(s, *workers)
+
+                            # Remove as much of the traceback as possible; it's
+                            # uninteresting boilerplate from utils_test and asyncio
+                            # and not from the code being tested.
+                            raise asyncio.TimeoutError(
+                                f"Test timeout after {timeout}s.\n"
+                                "========== Test stack trace starts here ==========\n"
+                                f"{buffer.getvalue()}"
+                            ) from None
+
+                        except pytest.xfail.Exception:
+                            raise
+
+                        except Exception:
+                            if cluster_dump_directory and not has_pytestmark(
+                                test_func, "xfail"
+                            ):
+                                await dump_cluster_state(
+                                    s,
+                                    ws,
+                                    output_dir=cluster_dump_directory,
+                                    func_name=func.__name__,
+                                )
+                            raise
+
+                        finally:
+                            if client and c.status not in ("closing", "closed"):
+                                await c._close(fast=s.status == Status.closed)
+                            await end_cluster(s, workers)
+                            await asyncio.wait_for(cleanup_global_workers(), 1)
+
+                            try:
+                                c = await default_client()
+                            except ValueError:
+                                pass
+                            else:
+                                await c._close(fast=True)
+
+                            def get_unclosed():
+                                return [c for c in Comm._instances if not c.closed()] + [
+                                    c
+                                    for c in _global_clients.values()
+                                    if c.status != "closed"
+                                ]
+
+                            try:
+                                start = time()
+                                while time() < start + 60:
+                                    gc.collect()
+                                    if not get_unclosed():
+                                        break
+                                    await asyncio.sleep(0.05)
+                                else:
+                                    if allow_unclosed:
+                                        print(f"Unclosed Comms: {get_unclosed()}")
                                     else:
-                                        if allow_unclosed:
-                                            print(f"Unclosed Comms: {get_unclosed()}")
-                                        else:
-                                            raise RuntimeError(
-                                                "Unclosed Comms", get_unclosed()
-                                            )
-                                finally:
-                                    Comm._instances.clear()
-                                    _global_clients.clear()
-
-                                    for w in workers:
-                                        if getattr(w, "data", None):
-                                            try:
-                                                w.data.clear()
-                                            except OSError:
-                                                # zict backends can fail if their storage directory
-                                                # was already removed
-                                                pass
-
-                            return result
-
-                result = loop.run_sync(
-                    coro, timeout=timeout * 2 if timeout else timeout
-                )
-
-                return result
+                                        raise RuntimeError("Unclosed Comms", get_unclosed())
+                            finally:
+                                Comm._instances.clear()
+                                _global_clients.clear()
+
+                                for w in workers:
+                                    if getattr(w, "data", None):
+                                        try:
+                                            w.data.clear()
+                                        except OSError:
+                                            # zict backends can fail if their storage directory
+                                            # was already removed
+                                            pass
+
+                return result
+
+            async def async_fn_outer():
+                async with _acheck_active_rpc(active_rpc_timeout=active_rpc_timeout):
+                    if timeout:
+                        return await asyncio.wait_for(async_fn(), timeout=timeout * 2)
+                    return await async_fn()
+
+            return _run_and_close_tornado(async_fn_outer)
 
         # Patch the signature so pytest can inject fixtures
         orig_sig = inspect.signature(func)
@@ -1838,26 +1881,31 @@ def check_instances():
 
 
 @contextmanager
-def clean(threads=True, instances=True, timeout=1, processes=True):
+def _reconfigure():
+    reset_config()
+
+    with dask.config.set(
+        {
+            "distributed.comm.timeouts.connect": "5s",
+            "distributed.admin.tick.interval": "500 ms",
+        }
+    ):
+        # Restore default logging levels
+        # XXX use pytest hooks/fixtures instead?
+        for name, level in logging_levels.items():
+            logging.getLogger(name).setLevel(level)
+
+        yield
+
+
+@contextmanager
+def clean(threads=True, instances=True, processes=True):
+    asyncio.set_event_loop(None)
     with check_thread_leak() if threads else nullcontext():
-        with pristine_loop() as loop:
-            with check_process_leak(check=processes):
-                with check_instances() if instances else nullcontext():
-                    with check_active_rpc(loop, timeout):
-                        reset_config()
-
-                        with dask.config.set(
-                            {
-                                "distributed.comm.timeouts.connect": "5s",
-                                "distributed.admin.tick.interval": "500 ms",
-                            }
-                        ):
-                            # Restore default logging levels
-                            # XXX use pytest hooks/fixtures instead?
-                            for name, level in logging_levels.items():
-                                logging.getLogger(name).setLevel(level)
-
-                            yield loop
+        with check_process_leak(check=processes):
+            with check_instances() if instances else nullcontext():
+                with _reconfigure():
+                    yield
 
 
 @pytest.fixture
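Note: the core helper this patch introduces is `_run_and_close_tornado`, which runs a coroutine to completion with `asyncio.run()` and afterwards closes the Tornado `IOLoop` that was implicitly created on top of the asyncio loop, including any file descriptors it still holds. A minimal standalone sketch of the same pattern (the `demo` coroutine is hypothetical, for illustration only):

    import asyncio

    from tornado.ioloop import IOLoop

    def run_and_close_tornado(async_fn, /, *args, **kwargs):
        tornado_loop = None

        async def inner_fn():
            nonlocal tornado_loop
            # IOLoop.current() returns the Tornado wrapper around the
            # asyncio loop that asyncio.run() just created
            tornado_loop = IOLoop.current()
            return await async_fn(*args, **kwargs)

        try:
            return asyncio.run(inner_fn())
        finally:
            # close the Tornado loop object and any fds it still owns
            tornado_loop.close(all_fds=True)

    async def demo():
        await asyncio.sleep(0)
        return "done"

    print(run_and_close_tornado(demo))  # prints: done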
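Note: the reworked `loop_in_thread` fixture replaces the hand-managed daemon thread with a `ThreadPoolExecutor`. The worker thread runs the event loop; the running `IOLoop` plus a shutdown `asyncio.Event` are handed back through a `concurrent.futures.Future`, and shutdown is requested thread-safely via `IOLoop.add_callback`. A condensed sketch of that handshake (standalone; the `main()` driver is hypothetical, and unlike the fixture it does not also watch the executor future for startup failures):

    import asyncio
    import concurrent.futures

    from tornado.ioloop import IOLoop

    def main():
        loop_started = concurrent.futures.Future()

        async def run():
            # runs inside the worker thread's event loop
            io_loop = IOLoop.current()
            stop_event = asyncio.Event()
            loop_started.set_result((io_loop, stop_event))
            await stop_event.wait()

        with concurrent.futures.ThreadPoolExecutor(1) as tpe:
            ran = tpe.submit(lambda: asyncio.run(run()))
            io_loop, stop_event = loop_started.result(timeout=5)
            # ... a test would use io_loop here ...
            io_loop.add_callback(stop_event.set)  # thread-safe stop signal
            ran.result(timeout=5)  # re-raise anything the loop raised

    main()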
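Note: after this change `gen_test` no longer falls back to `tornado.gen.coroutine` for plain functions; a non-coroutine test raises `RuntimeError` when invoked, and the coroutine runs under `asyncio.run` (via `_run_and_close_tornado`) with the active-RPC check wrapped around it. A usage sketch (the test name and body are hypothetical):

    import asyncio

    from distributed.utils_test import gen_test

    @gen_test(timeout=30)
    async def test_example():
        # must be an `async def`; a plain `def` would raise RuntimeError
        await asyncio.sleep(0.01)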