diff --git a/continuous_integration/environment-3.10.yaml b/continuous_integration/environment-3.10.yaml index 2f8f9da2213..7505a07e2b0 100644 --- a/continuous_integration/environment-3.10.yaml +++ b/continuous_integration/environment-3.10.yaml @@ -10,7 +10,7 @@ dependencies: - bokeh - click - cloudpickle - - coverage<6.3 # https://github.com/nedbat/coveragepy/issues/1310 + - coverage - dask # overridden by git tip below - filesystem-spec # overridden by git tip below - h5py diff --git a/continuous_integration/gpuci/axis.yaml b/continuous_integration/gpuci/axis.yaml index 41ddb56ec2d..a4773b2e3d3 100644 --- a/continuous_integration/gpuci/axis.yaml +++ b/continuous_integration/gpuci/axis.yaml @@ -8,6 +8,6 @@ LINUX_VER: - ubuntu18.04 RAPIDS_VER: -- "22.06" +- "22.08" excludes: diff --git a/distributed/_signals.py b/distributed/_signals.py index f25b7245c41..730e14d0fc7 100644 --- a/distributed/_signals.py +++ b/distributed/_signals.py @@ -25,4 +25,8 @@ def handle_signal(signum, frame): for sig in signals: old_handlers[sig] = signal.signal(sig, handle_signal) - await event.wait() + try: + await event.wait() + finally: + for sig in signals: + signal.signal(sig, old_handlers[sig]) diff --git a/distributed/chaos.py b/distributed/chaos.py index 1bcba1f7950..87255f1dd67 100644 --- a/distributed/chaos.py +++ b/distributed/chaos.py @@ -56,9 +56,7 @@ async def setup(self, worker): ) def graceful(self): - asyncio.create_task( - self.worker.close(report=False, nanny=False, executor_wait=False) - ) + asyncio.create_task(self.worker.close(nanny=False, executor_wait=False)) def sys_exit(self): sys.exit(0) diff --git a/distributed/cli/dask_scheduler.py b/distributed/cli/dask_scheduler.py index b976a5c3789..c6c43da01aa 100755 --- a/distributed/cli/dask_scheduler.py +++ b/distributed/cli/dask_scheduler.py @@ -9,7 +9,6 @@ import warnings import click -from tornado.ioloop import IOLoop from distributed import Scheduler from distributed._signals import wait_for_signals @@ -186,11 +185,9 @@ def del_pid_file(): resource.setrlimit(resource.RLIMIT_NOFILE, (limit, hard)) async def run(): - loop = IOLoop.current() logger.info("-" * 47) scheduler = Scheduler( - loop=loop, security=sec, host=host, port=port, diff --git a/distributed/cli/tests/test_dask_scheduler.py b/distributed/cli/tests/test_dask_scheduler.py index 1b01d1ad355..308ee86a011 100644 --- a/distributed/cli/tests/test_dask_scheduler.py +++ b/distributed/cli/tests/test_dask_scheduler.py @@ -26,6 +26,7 @@ assert_can_connect_from_everywhere_4_6, assert_can_connect_locally_4, popen, + wait_for_log_line, ) @@ -66,12 +67,8 @@ def test_dashboard(loop): pytest.importorskip("bokeh") with popen(["dask-scheduler"], flush_output=False) as proc: - for line in proc.stdout: - if b"dashboard at" in line: - dashboard_port = int(line.decode().split(":")[-1].strip()) - break - else: - assert False # pragma: nocover + line = wait_for_log_line(b"dashboard at", proc.stdout) + dashboard_port = int(line.decode().split(":")[-1].strip()) with Client(f"127.0.0.1:{Scheduler.default_port}", loop=loop): pass @@ -223,13 +220,9 @@ def test_dashboard_port_zero(loop): ["dask-scheduler", "--dashboard-address", ":0"], flush_output=False, ) as proc: - for line in proc.stdout: - if b"dashboard at" in line: - dashboard_port = int(line.decode().split(":")[-1].strip()) - assert dashboard_port != 0 - break - else: - assert False # pragma: nocover + line = wait_for_log_line(b"dashboard at", proc.stdout) + dashboard_port = int(line.decode().split(":")[-1].strip()) + assert dashboard_port != 0 
PRELOAD_TEXT = """ @@ -413,7 +406,7 @@ def test_version_option(): @pytest.mark.slow -def test_idle_timeout(loop): +def test_idle_timeout(): start = time() runner = CliRunner() result = runner.invoke( @@ -424,6 +417,23 @@ def test_idle_timeout(loop): assert result.exit_code == 0 +@pytest.mark.slow +def test_restores_signal_handler(): + # another test could have altered the signal handler, so use a new function + # that both has sensible sigint behaviour *and* can be used as a sentinel + def raise_ki(): + raise KeyboardInterrupt + + original_handler = signal.signal(signal.SIGINT, raise_ki) + try: + CliRunner().invoke( + distributed.cli.dask_scheduler.main, ["--idle-timeout", "1s"] + ) + assert signal.getsignal(signal.SIGINT) is raise_ki + finally: + signal.signal(signal.SIGINT, original_handler) + + def test_multiple_workers_2(loop): text = """ def dask_setup(worker): diff --git a/distributed/cli/tests/test_dask_ssh.py b/distributed/cli/tests/test_dask_ssh.py index e087382fff8..d7b011737f7 100644 --- a/distributed/cli/tests/test_dask_ssh.py +++ b/distributed/cli/tests/test_dask_ssh.py @@ -4,7 +4,7 @@ from distributed import Client from distributed.cli.dask_ssh import main from distributed.compatibility import MACOS, WINDOWS -from distributed.utils_test import popen +from distributed.utils_test import popen, wait_for_log_line pytest.importorskip("paramiko") pytestmark = [ @@ -30,9 +30,7 @@ def test_ssh_cli_nprocs_renamed_to_nworkers(loop): # This interrupt is necessary for the cluster to place output into the stdout # and stderr pipes proc.send_signal(2) - assert any( - b"renamed to --nworkers" in proc.stdout.readline() for _ in range(15) - ) + wait_for_log_line(b"renamed to --nworkers", proc.stdout, max_lines=15) def test_ssh_cli_nworkers_with_nprocs_is_an_error(): @@ -40,6 +38,4 @@ def test_ssh_cli_nworkers_with_nprocs_is_an_error(): ["dask-ssh", "localhost", "--nprocs=2", "--nworkers=2"], flush_output=False, ) as proc: - assert any( - b"Both --nprocs and --nworkers" in proc.stdout.readline() for _ in range(15) - ) + wait_for_log_line(b"Both --nprocs and --nworkers", proc.stdout, max_lines=15) diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index a67b4e241c7..ca8ed37aac5 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -18,7 +18,8 @@ from distributed.compatibility import LINUX, WINDOWS from distributed.deploy.utils import nprocesses_nthreads from distributed.metrics import time -from distributed.utils_test import gen_cluster, popen, requires_ipv6 +from distributed.utils import open_port +from distributed.utils_test import gen_cluster, popen, requires_ipv6, wait_for_log_line @pytest.mark.parametrize( @@ -245,9 +246,7 @@ async def test_nanny_worker_port_range_too_many_workers_raises(s): ], flush_output=False, ) as worker: - assert any( - b"Not enough ports in range" in worker.stdout.readline() for _ in range(100) - ) + wait_for_log_line(b"Not enough ports in range", worker.stdout, max_lines=100) @pytest.mark.slow @@ -281,26 +280,14 @@ async def test_reconnect_deprecated(c, s): ["dask-worker", s.address, "--reconnect"], flush_output=False, ) as worker: - for _ in range(10): - line = worker.stdout.readline() - print(line) - if b"`--reconnect` option has been removed" in line: - break - else: - raise AssertionError("Message not printed, see stdout") + wait_for_log_line(b"`--reconnect` option has been removed", worker.stdout) assert worker.wait() == 1 with popen( ["dask-worker", s.address, 
"--no-reconnect"], flush_output=False, ) as worker: - for _ in range(10): - line = worker.stdout.readline() - print(line) - if b"flag is deprecated, and will be removed" in line: - break - else: - raise AssertionError("Message not printed, see stdout") + wait_for_log_line(b"flag is deprecated, and will be removed", worker.stdout) await c.wait_for_workers(1) await c.shutdown() @@ -376,9 +363,7 @@ async def test_nworkers_requires_nanny(s): ["dask-worker", s.address, "--nworkers=2", "--no-nanny"], flush_output=False, ) as worker: - assert any( - b"Failed to launch worker" in worker.stdout.readline() for _ in range(15) - ) + wait_for_log_line(b"Failed to launch worker", worker.stdout, max_lines=15) @pytest.mark.slow @@ -418,9 +403,7 @@ async def test_worker_cli_nprocs_renamed_to_nworkers(c, s): flush_output=False, ) as worker: await c.wait_for_workers(2) - assert any( - b"renamed to --nworkers" in worker.stdout.readline() for _ in range(15) - ) + wait_for_log_line(b"renamed to --nworkers", worker.stdout, max_lines=15) @gen_cluster(nthreads=[]) @@ -429,10 +412,7 @@ async def test_worker_cli_nworkers_with_nprocs_is_an_error(s): ["dask-worker", s.address, "--nprocs=2", "--nworkers=2"], flush_output=False, ) as worker: - assert any( - b"Both --nprocs and --nworkers" in worker.stdout.readline() - for _ in range(15) - ) + wait_for_log_line(b"Both --nprocs and --nworkers", worker.stdout, max_lines=15) @pytest.mark.slow @@ -713,3 +693,37 @@ async def test_signal_handling(c, s, nanny, sig): assert "timed out" not in logs assert "error" not in logs assert "exception" not in logs + + +@pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) +def test_error_during_startup(monkeypatch, nanny): + # see https://github.com/dask/distributed/issues/6320 + scheduler_port = str(open_port()) + scheduler_addr = f"tcp://127.0.0.1:{scheduler_port}" + + monkeypatch.setenv("DASK_SCHEDULER_ADDRESS", scheduler_addr) + with popen( + [ + "dask-scheduler", + "--port", + scheduler_port, + ], + flush_output=False, + ) as scheduler: + start = time() + # Wait for the scheduler to be up + wait_for_log_line(b"Scheduler at", scheduler.stdout) + # Ensure this is not killed by pytest-timeout + if time() - start > 5: + raise TimeoutError("Scheduler failed to start in time.") + + with popen( + [ + "dask-worker", + scheduler_addr, + nanny, + "--worker-port", + scheduler_port, + ], + ) as worker: + assert worker.wait(5) == 1 diff --git a/distributed/client.py b/distributed/client.py index caf42dc19ad..af52607c9df 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1630,7 +1630,7 @@ async def _shutdown(self): else: with suppress(CommClosedError): self.status = "closing" - await self.scheduler.terminate(close_workers=True) + await self.scheduler.terminate() def shutdown(self): """Shut down the connected scheduler and workers diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index db57bc579e6..fae1e5b6d23 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -563,7 +563,7 @@ async def client_communicate(key, delay=0): @pytest.mark.gpu @gen_test() -async def test_ucx_client_server(): +async def test_ucx_client_server(ucx_loop): pytest.importorskip("distributed.comm.ucx") ucp = pytest.importorskip("ucp") diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index f4a5729826a..79fc298284b 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -22,25 +22,7 @@ HOST = "127.0.0.1" 
-def handle_exception(loop, context): - msg = context.get("exception", context["message"]) - print(msg) - - -# Let's make sure that UCX gets time to cancel -# progress tasks before closing the event loop. -@pytest.fixture() -def event_loop(scope="function"): - loop = asyncio.new_event_loop() - loop.set_exception_handler(handle_exception) - ucp.reset() - yield loop - ucp.reset() - loop.run_until_complete(asyncio.sleep(0)) - loop.close() - - -def test_registered(): +def test_registered(ucx_loop): assert "ucx" in backends backend = get_backend("ucx") assert isinstance(backend, ucx.UCXBackend) @@ -62,7 +44,7 @@ async def handle_comm(comm): @gen_test() -async def test_ping_pong(): +async def test_ping_pong(ucx_loop): com, serv_com = await get_comm_pair() msg = {"op": "ping"} await com.write(msg) @@ -80,7 +62,7 @@ async def test_ping_pong(): @gen_test() -async def test_comm_objs(): +async def test_comm_objs(ucx_loop): comm, serv_comm = await get_comm_pair() scheme, loc = parse_address(comm.peer_address) @@ -93,7 +75,7 @@ async def test_comm_objs(): @gen_test() -async def test_ucx_specific(): +async def test_ucx_specific(ucx_loop): """ Test concrete UCX API. """ @@ -147,7 +129,7 @@ async def client_communicate(key, delay=0): @gen_test() -async def test_ping_pong_data(): +async def test_ping_pong_data(ucx_loop): np = pytest.importorskip("numpy") data = np.ones((10, 10)) @@ -170,7 +152,7 @@ async def test_ping_pong_data(): @gen_test() -async def test_ucx_deserialize(): +async def test_ucx_deserialize(ucx_loop): # Note we see this error on some systems with this test: # `socket.gaierror: [Errno -5] No address associated with hostname` # This may be due to a system configuration issue. @@ -196,7 +178,7 @@ async def test_ucx_deserialize(): ], ) @gen_test() -async def test_ping_pong_cudf(g): +async def test_ping_pong_cudf(ucx_loop, g): # if this test appears after cupy an import error arises # *** ImportError: /usr/lib/x86_64-linux-gnu/libstdc++.so.6: version `CXXABI_1.3.11' # not found (required by python3.7/site-packages/pyarrow/../../../libarrow.so.12) @@ -221,7 +203,7 @@ async def test_ping_pong_cudf(g): @pytest.mark.parametrize("shape", [(100,), (10, 10), (4947,)]) @gen_test() -async def test_ping_pong_cupy(shape): +async def test_ping_pong_cupy(ucx_loop, shape): cupy = pytest.importorskip("cupy") com, serv_com = await get_comm_pair() @@ -240,7 +222,7 @@ async def test_ping_pong_cupy(shape): @pytest.mark.slow @pytest.mark.parametrize("n", [int(1e9), int(2.5e9)]) @gen_test() -async def test_large_cupy(n, cleanup): +async def test_large_cupy(ucx_loop, n, cleanup): cupy = pytest.importorskip("cupy") com, serv_com = await get_comm_pair() @@ -257,7 +239,7 @@ async def test_large_cupy(n, cleanup): @gen_test() -async def test_ping_pong_numba(): +async def test_ping_pong_numba(ucx_loop): np = pytest.importorskip("numpy") numba = pytest.importorskip("numba") import numba.cuda @@ -276,7 +258,7 @@ async def test_ping_pong_numba(): @pytest.mark.parametrize("processes", [True, False]) @gen_test() -async def test_ucx_localcluster(processes, cleanup): +async def test_ucx_localcluster(ucx_loop, processes, cleanup): async with LocalCluster( protocol="ucx", host=HOST, @@ -297,7 +279,9 @@ async def test_ucx_localcluster(processes, cleanup): @pytest.mark.slow @gen_test(timeout=60) -async def test_stress(): +async def test_stress( + ucx_loop, +): da = pytest.importorskip("dask.array") chunksize = "10 MB" @@ -322,7 +306,9 @@ async def test_stress(): @gen_test() -async def test_simple(): +async def test_simple( + 
ucx_loop, +): async with LocalCluster(protocol="ucx", asynchronous=True) as cluster: async with Client(cluster, asynchronous=True) as client: assert cluster.scheduler_address.startswith("ucx://") @@ -330,7 +316,9 @@ async def test_simple(): @gen_test() -async def test_cuda_context(): +async def test_cuda_context( + ucx_loop, +): with dask.config.set({"distributed.comm.ucx.create-cuda-context": True}): async with LocalCluster( protocol="ucx", n_workers=1, asynchronous=True @@ -344,7 +332,9 @@ async def test_cuda_context(): @gen_test() -async def test_transpose(): +async def test_transpose( + ucx_loop, +): da = pytest.importorskip("dask.array") async with LocalCluster(protocol="ucx", asynchronous=True) as cluster: @@ -358,7 +348,7 @@ async def test_transpose(): @pytest.mark.parametrize("port", [0, 1234]) @gen_test() -async def test_ucx_protocol(cleanup, port): +async def test_ucx_protocol(ucx_loop, cleanup, port): async with Scheduler(protocol="ucx", port=port, dashboard_address=":0") as s: assert s.address.startswith("ucx://") @@ -367,10 +357,9 @@ async def test_ucx_protocol(cleanup, port): not hasattr(ucp.exceptions, "UCXUnreachable"), reason="Requires UCX-Py support for UCXUnreachable exception", ) -def test_ucx_unreachable(): - if ucp.get_ucx_version() > (1, 12, 0): - with pytest.raises(OSError, match="Timed out trying to connect to"): - Client("ucx://255.255.255.255:12345", timeout=1) - else: - with pytest.raises(ucp.exceptions.UCXError, match="Destination is unreachable"): - Client("ucx://255.255.255.255:12345", timeout=1) +@gen_test() +async def test_ucx_unreachable( + ucx_loop, +): + with pytest.raises(OSError, match="Timed out trying to connect to"): + await Client("ucx://255.255.255.255:12345", timeout=1, asynchronous=True) diff --git a/distributed/comm/tests/test_ucx_config.py b/distributed/comm/tests/test_ucx_config.py index f1f3f08a3ab..baff89e6111 100644 --- a/distributed/comm/tests/test_ucx_config.py +++ b/distributed/comm/tests/test_ucx_config.py @@ -22,7 +22,7 @@ @gen_test() -async def test_ucx_config(cleanup): +async def test_ucx_config(ucx_loop, cleanup): ucx = { "nvlink": True, "infiniband": True, @@ -79,7 +79,7 @@ async def test_ucx_config(cleanup): reruns=10, reruns_delay=5, ) -def test_ucx_config_w_env_var(cleanup, loop): +def test_ucx_config_w_env_var(ucx_loop, cleanup, loop): env = os.environ.copy() env["DASK_RMM__POOL_SIZE"] = "1000.00 MB" diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 739cfe39e74..b4742a7c650 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -397,11 +397,7 @@ async def connect(self, address: str, deserialize=True, **connection_args) -> UC init_once() try: ep = await ucp.create_endpoint(ip, port) - except (ucp.exceptions.UCXCloseError, ucp.exceptions.UCXCanceled,) + ( - getattr(ucp.exceptions, "UCXConnectionReset", ()), - getattr(ucp.exceptions, "UCXNotConnected", ()), - getattr(ucp.exceptions, "UCXUnreachable", ()), - ): # type: ignore + except ucp.exceptions.UCXBaseException: raise CommClosedError("Connection closed before handshake completed") return self.comm_class( ep, diff --git a/distributed/core.py b/distributed/core.py index d54592ef5a1..2a2a12521ff 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -187,7 +187,7 @@ def __init__( self.monitor = SystemMonitor() self.counters = None self.digests = None - self._ongoing_coroutines = weakref.WeakSet() + self._ongoing_coroutines = set() self._event_finished = asyncio.Event() self.listeners = [] @@ -508,7 +508,7 @@ async def handle_comm(self, 
comm): await self try: - while True: + while not self.__stopped: try: msg = await comm.read() logger.debug("Message from %r: %s", address, msg) @@ -579,10 +579,17 @@ async def handle_comm(self, comm): result = handler(comm, **msg) else: result = handler(**msg) - if inspect.isawaitable(result): - result = asyncio.ensure_future(result) + if inspect.iscoroutine(result): + result = asyncio.create_task( + result, name=f"handle-comm-{address}-{op}" + ) self._ongoing_coroutines.add(result) + result.add_done_callback(self._ongoing_coroutines.remove) result = await result + elif inspect.isawaitable(result): + raise RuntimeError( + f"Comm handler returned unknown awaitable. Expected coroutine, instead got {type(result)}" + ) except CommClosedError: if self.status == Status.running: logger.info("Lost connection to %r", address, exc_info=True) @@ -666,34 +673,36 @@ async def handle_stream(self, comm, extra=None): await comm.close() assert comm.closed() - @gen.coroutine - def close(self): + async def close(self, timeout=None): for pc in self.periodic_callbacks.values(): pc.stop() if not self.__stopped: self.__stopped = True + _stops = set() for listener in self.listeners: future = listener.stop() if inspect.isawaitable(future): - yield future - for i in range(20): - # If there are still handlers running at this point, give them a - # second to finish gracefully themselves, otherwise... - if any(self._comms.values()): - yield asyncio.sleep(0.05) - else: - break - yield self.rpc.close() - yield [comm.close() for comm in list(self._comms)] # then forcefully close - for cb in self._ongoing_coroutines: - cb.cancel() - for i in range(10): - if all(c.cancelled() for c in self._ongoing_coroutines): - break - else: - yield asyncio.sleep(0.01) + _stops.add(future) + await asyncio.gather(*_stops) + + def _ongoing_tasks(): + return ( + t for t in self._ongoing_coroutines if t is not asyncio.current_task() + ) + + # TODO: Deal with exceptions + try: + # Give the handlers a bit of time to finish gracefully + await asyncio.wait_for( + asyncio.gather(*_ongoing_tasks(), return_exceptions=True), 1 + ) + except asyncio.TimeoutError: + # the timeout on gather should've cancelled all the tasks + await asyncio.gather(*_ongoing_tasks(), return_exceptions=True) + await self.rpc.close() + await asyncio.gather(*[comm.close() for comm in list(self._comms)]) self._event_finished.set() diff --git a/distributed/dashboard/tests/test_components.py b/distributed/dashboard/tests/test_components.py index bc9f6c74849..f268c279b32 100644 --- a/distributed/dashboard/tests/test_components.py +++ b/distributed/dashboard/tests/test_components.py @@ -25,9 +25,9 @@ def test_basic(Component): async def test_profile_plot(c, s, a, b): p = ProfilePlot() assert not p.source.data["left"] - await c.gather(c.map(slowinc, range(10), delay=0.05)) - p.update(a.profile_recent) - assert len(p.source.data["left"]) >= 1 + while not len(p.source.data["left"]): + await c.submit(slowinc, 1, pure=False, delay=0.1) + p.update(a.profile_recent) @gen_cluster(client=True, clean_kwargs={"threads": False}) @@ -40,8 +40,8 @@ async def test_profile_time_plot(c, s, a, b): ap = ProfileTimePlot(a, doc=curdoc()) ap.trigger_update() - assert len(sp.source.data["left"]) <= 1 - assert len(ap.source.data["left"]) <= 1 + assert not len(sp.source.data["left"]) + assert not len(ap.source.data["left"]) await c.gather(c.map(slowinc, range(10), delay=0.05)) ap.trigger_update() diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 13e81f5f82d..d2d0da82ae1 
100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -427,10 +427,13 @@ def update(): cluster_repr_interval = parse_timedelta( dask.config.get("distributed.deploy.cluster-repr-interval", default="ms") ) - pc = PeriodicCallback(update, cluster_repr_interval * 1000) - self.periodic_callbacks["cluster-repr"] = pc - pc.start() + def install(): + pc = PeriodicCallback(update, cluster_repr_interval * 1000) + self.periodic_callbacks["cluster-repr"] = pc + pc.start() + + self.loop.add_callback(install) return tab def _repr_html_(self, cluster_status=None): diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index 96a66d63aed..f24ce7b05d0 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -411,14 +411,18 @@ async def _close(self): if self.scheduler_comm: async with self._lock: with suppress(OSError): - await self.scheduler_comm.terminate(close_workers=True) + await self.scheduler_comm.terminate() await self.scheduler_comm.close_rpc() else: logger.warning("Cluster closed without starting up") await self.scheduler.close() for w in self._created: - assert w.status in {Status.closed, Status.failed}, w.status + assert w.status in { + Status.closing, + Status.closed, + Status.failed, + }, w.status if hasattr(self, "_old_logging_level"): silence_logging(self._old_logging_level) diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index fca1fc4c550..781dc29a131 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -12,10 +12,9 @@ from dask.system import CPU_COUNT -from distributed import Client, Nanny, Worker, get_client +from distributed import Client, LocalCluster, Nanny, Worker, get_client from distributed.compatibility import LINUX from distributed.core import Status -from distributed.deploy.local import LocalCluster from distributed.deploy.utils_test import ClusterTest from distributed.metrics import time from distributed.system import MEMORY_LIMIT @@ -29,6 +28,7 @@ clean, gen_test, inc, + raises_with_cause, slowinc, tls_only_security, xfail_ssl_issue5601, @@ -582,6 +582,34 @@ def test_ipywidgets(loop): assert isinstance(box, ipywidgets.Widget) +def test_ipywidgets_loop(loop): + """ + Previously cluster._ipython_display_ attached the PeriodicCallback to the + currently running loop. See https://github.com/dask/distributed/pull/6444 + """ + ipywidgets = pytest.importorskip("ipywidgets") + + async def get_ioloop(cluster): + return cluster.periodic_callbacks["cluster-repr"].io_loop + + async def amain(): + # running synchronous code in an async context to set up an + # IOLoop.current() that's different from cluster.loop + with LocalCluster( + n_workers=0, + silence_logs=False, + loop=loop, + dashboard_address=":0", + processes=False, + ) as cluster: + cluster._ipython_display_() + assert cluster.sync(get_ioloop, cluster) is loop + box = cluster._cached_widget + assert isinstance(box, ipywidgets.Widget) + + asyncio.run(amain()) + + def test_no_ipywidgets(loop, monkeypatch): from unittest.mock import MagicMock @@ -1155,3 +1183,26 @@ async def test_connect_to_closed_cluster(): # Raises during init without actually connecting since we're not # awaiting anything Client(cluster, asynchronous=True) + + +class MyPlugin: + def setup(self, worker=None): + import my_nonexistent_library # noqa + + +@pytest.mark.slow +@gen_test( + clean_kwargs={ + # FIXME: This doesn't close the LoopRunner properly, leaving a thread around + "threads": False + } +) +async def 
test_localcluster_start_exception(): + with raises_with_cause(RuntimeError, None, ImportError, "my_nonexistent_library"): + async with LocalCluster( + n_workers=1, + threads_per_worker=1, + processes=True, + plugins={MyPlugin()}, + ): + return diff --git a/distributed/diagnostics/tests/test_cluster_dump_plugin.py b/distributed/diagnostics/tests/test_cluster_dump_plugin.py index 67ce815954d..b084e761603 100644 --- a/distributed/diagnostics/tests/test_cluster_dump_plugin.py +++ b/distributed/diagnostics/tests/test_cluster_dump_plugin.py @@ -14,7 +14,7 @@ async def test_cluster_dump_plugin(c, s, *workers, tmp_path): f2 = c.submit(inc, f1) assert (await f2) == 3 - await s.close(close_workers=True) + await s.close() dump = DumpArtefact.from_url(str(dump_file)) assert {f1.key, f2.key} == set(dump.scheduler_story(f1.key, f2.key).keys()) diff --git a/distributed/diagnostics/tests/test_scheduler_plugin.py b/distributed/diagnostics/tests/test_scheduler_plugin.py index 59511ba456a..5c433e7a1af 100644 --- a/distributed/diagnostics/tests/test_scheduler_plugin.py +++ b/distributed/diagnostics/tests/test_scheduler_plugin.py @@ -1,7 +1,9 @@ +import logging + import pytest from distributed import Scheduler, SchedulerPlugin, Worker, get_worker -from distributed.utils_test import gen_cluster, gen_test, inc +from distributed.utils_test import captured_logger, gen_cluster, gen_test, inc @gen_cluster(client=True) @@ -198,3 +200,46 @@ def f(): await c.submit(f) assert ("foo", 123) in s._recorded_events + + +@gen_cluster(client=True) +async def test_register_plugin_on_scheduler(c, s, a, b): + class MyPlugin(SchedulerPlugin): + async def start(self, scheduler: Scheduler) -> None: + scheduler._foo = "bar" # type: ignore + + await s.register_scheduler_plugin(MyPlugin()) + + assert s._foo == "bar" + + +@gen_cluster(client=True) +async def test_closing_errors_ok(c, s, a, b, capsys): + class OK(SchedulerPlugin): + async def before_close(self): + print(123) + + async def close(self): + print(456) + + class Bad(SchedulerPlugin): + async def before_close(self): + raise Exception("BEFORE_CLOSE") + + async def close(self): + raise Exception("AFTER_CLOSE") + + await s.register_scheduler_plugin(OK()) + await s.register_scheduler_plugin(Bad()) + + with captured_logger(logging.getLogger("distributed.scheduler")) as logger: + await s.close() + + out, err = capsys.readouterr() + assert "123" in out + assert "456" in out + + text = logger.getvalue() + assert "BEFORE_CLOSE" in text + text = logger.getvalue() + assert "AFTER_CLOSE" in text diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 931c7438ad2..74d59addb35 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -53,7 +53,6 @@ distributed: - distributed.http.scheduler.prometheus - distributed.http.scheduler.info - distributed.http.scheduler.json - - distributed.http.scheduler.api - distributed.http.health - distributed.http.proxy - distributed.http.statics diff --git a/distributed/http/scheduler/tests/test_scheduler_http.py b/distributed/http/scheduler/tests/test_scheduler_http.py index 94f7563039e..e86f004b47e 100644 --- a/distributed/http/scheduler/tests/test_scheduler_http.py +++ b/distributed/http/scheduler/tests/test_scheduler_http.py @@ -10,11 +10,14 @@ from tornado.escape import url_escape from tornado.httpclient import AsyncHTTPClient, HTTPClientError +import dask.config from dask.sizeof import sizeof from distributed.utils import is_valid_xml from distributed.utils_test import gen_cluster, inc, slowinc 
+DEFAULT_ROUTES = dask.config.get("distributed.scheduler.http.routes") + @gen_cluster(client=True) async def test_connect(c, s, a, b): @@ -248,7 +251,20 @@ async def test_eventstream(c, s, a, b): ws_client.close() -@gen_cluster(client=True, clean_kwargs={"threads": False}) +def test_api_disabled_by_default(): + assert "distributed.http.scheduler.api" not in dask.config.get( + "distributed.scheduler.http.routes" + ) + + +@gen_cluster( + client=True, + clean_kwargs={"threads": False}, + config={ + "distributed.scheduler.http.routes": DEFAULT_ROUTES + + ["distributed.http.scheduler.api"] + }, +) async def test_api(c, s, a, b): async with aiohttp.ClientSession() as session: async with session.get( @@ -259,7 +275,14 @@ async def test_api(c, s, a, b): assert (await resp.text()) == "API V1" -@gen_cluster(client=True, clean_kwargs={"threads": False}) +@gen_cluster( + client=True, + clean_kwargs={"threads": False}, + config={ + "distributed.scheduler.http.routes": DEFAULT_ROUTES + + ["distributed.http.scheduler.api"] + }, +) async def test_retire_workers(c, s, a, b): async with aiohttp.ClientSession() as session: params = {"workers": [a.address, b.address]} @@ -273,7 +296,14 @@ async def test_retire_workers(c, s, a, b): assert len(retired_workers_info) == 2 -@gen_cluster(client=True, clean_kwargs={"threads": False}) +@gen_cluster( + client=True, + clean_kwargs={"threads": False}, + config={ + "distributed.scheduler.http.routes": DEFAULT_ROUTES + + ["distributed.http.scheduler.api"] + }, +) async def test_get_workers(c, s, a, b): async with aiohttp.ClientSession() as session: async with session.get( @@ -286,7 +316,14 @@ async def test_get_workers(c, s, a, b): assert set(workers_address) == {a.address, b.address} -@gen_cluster(client=True, clean_kwargs={"threads": False}) +@gen_cluster( + client=True, + clean_kwargs={"threads": False}, + config={ + "distributed.scheduler.http.routes": DEFAULT_ROUTES + + ["distributed.http.scheduler.api"] + }, +) async def test_adaptive_target(c, s, a, b): async with aiohttp.ClientSession() as session: async with session.get( diff --git a/distributed/nanny.py b/distributed/nanny.py index f8fb483c4c3..f23b78f8e58 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -16,7 +16,6 @@ from time import sleep as sync_sleep from typing import TYPE_CHECKING, ClassVar -import psutil from tornado import gen from tornado.ioloop import IOLoop @@ -470,7 +469,7 @@ async def plugin_remove(self, name=None): return {"status": "OK"} - async def restart(self, timeout=30, executor_wait=True): + async def restart(self, timeout=30): async def _(): if self.process is not None: await self.kill() @@ -486,19 +485,6 @@ async def _(): else: return "OK" - @property - def _psutil_process(self): - pid = self.process.process.pid - try: - self._psutil_process_obj - except AttributeError: - self._psutil_process_obj = psutil.Process(pid) - - if self._psutil_process_obj.pid != pid: - self._psutil_process_obj = psutil.Process(pid) - - return self._psutil_process_obj - def is_alive(self): return self.process is not None and self.process.is_alive() @@ -556,7 +542,7 @@ def close_gracefully(self): """ self.status = Status.closing_gracefully - async def close(self, comm=None, timeout=5, report=None): + async def close(self, timeout=5): """ Close the worker process, stop all comms. """ @@ -569,9 +555,8 @@ async def close(self, comm=None, timeout=5, report=None): self.status = Status.closing logger.info( - "Closing Nanny at %r. 
Report closure to scheduler: %s", + "Closing Nanny at %r.", self.address_safe, - report, ) for preload in self.preloads: @@ -594,9 +579,8 @@ async def close(self, comm=None, timeout=5, report=None): self.process = None await self.rpc.close() self.status = Status.closed - if comm: - await comm.write("OK") await super().close() + return "OK" async def _log_event(self, topic, msg): await self.scheduler.log_event( @@ -837,9 +821,7 @@ def _run( async def do_stop(timeout=5, executor_wait=True): try: await worker.close( - report=True, nanny=False, - safe=True, # TODO: Graceful or not? executor_wait=executor_wait, timeout=timeout, ) diff --git a/distributed/process.py b/distributed/process.py index 4597f276ef6..debbf025cc6 100644 --- a/distributed/process.py +++ b/distributed/process.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import asyncio import logging +import multiprocessing import os import re import threading @@ -50,6 +53,8 @@ class AsyncProcess: All normally blocking methods are wrapped in Tornado coroutines. """ + _process: multiprocessing.Process + def __init__(self, loop=None, target=None, name=None, args=(), kwargs={}): if not callable(target): raise TypeError(f"`target` needs to be callable, not {type(target)!r}") @@ -175,7 +180,9 @@ def _run( target(*args, **kwargs) @classmethod - def _watch_message_queue(cls, selfref, process, loop, state, q, exit_future): + def _watch_message_queue( + cls, selfref, process: multiprocessing.Process, loop, state, q, exit_future + ): # As multiprocessing.Process is not thread-safe, we run all # blocking operations from this single loop and ship results # back to the caller when needed. @@ -204,7 +211,12 @@ def _start(): if op == "start": _call_and_set_future(loop, msg["future"], _start) elif op == "terminate": + # Send SIGTERM _call_and_set_future(loop, msg["future"], process.terminate) + elif op == "kill": + # Send SIGKILL + _call_and_set_future(loop, msg["future"], process.kill) + elif op == "stop": break else: @@ -240,17 +252,35 @@ def start(self): self._watch_q.put_nowait({"op": "start", "future": fut}) return fut - def terminate(self): - """ - Terminate the child process. + def terminate(self) -> asyncio.Future[None]: + """Terminate the child process. This method returns a future. + + See also + -------- + multiprocessing.Process.terminate """ self._check_closed() - fut = Future() + fut: Future[None] = Future() self._watch_q.put_nowait({"op": "terminate", "future": fut}) return fut + def kill(self) -> asyncio.Future[None]: + """Send SIGKILL to the child process. + On Windows, this is the same as terminate(). + + This method returns a future. + + See also + -------- + multiprocessing.Process.kill + """ + self._check_closed() + fut: Future[None] = Future() + self._watch_q.put_nowait({"op": "kill", "future": fut}) + return fut + async def join(self, timeout=None): """ Wait for the child process to exit. 
diff --git a/distributed/profile.py b/distributed/profile.py index a3a7bef94b9..5432637e5d5 100644 --- a/distributed/profile.py +++ b/distributed/profile.py @@ -44,6 +44,9 @@ from distributed.metrics import time from distributed.utils import color_of +#: This lock can be acquired to ensure that no instance of watch() is concurrently holding references to frames +lock = threading.Lock() + def identifier(frame: FrameType | None) -> str: """A string identifier from a frame @@ -314,18 +317,6 @@ def traverse(state, start, stop, height): } -_watch_running: set[int] = set() - - -def wait_profiler() -> None: - """Wait until a moment when no instances of watch() are sampling the frames. - You must call this function whenever you would otherwise expect an object to be - immediately released after it's descoped. - """ - while _watch_running: - sleep(0.0001) - - def _watch( thread_id: int, log: deque[tuple[float, dict[str, Any]]], # [(timestamp, output of create()), ...] @@ -337,24 +328,20 @@ def _watch( recent = create() last = time() - watch_id = threading.get_ident() while not stop(): - _watch_running.add(watch_id) - try: - if time() > last + cycle: + if time() > last + cycle: + recent = create() + with lock: log.append((time(), recent)) - recent = create() last = time() - try: - frame = sys._current_frames()[thread_id] - except KeyError: - return - - process(frame, None, recent, omit=omit) - del frame - finally: - _watch_running.remove(watch_id) + try: + frame = sys._current_frames()[thread_id] + except KeyError: + return + + process(frame, None, recent, omit=omit) + del frame sleep(interval) diff --git a/distributed/protocol/tests/test_pickle.py b/distributed/protocol/tests/test_pickle.py index 7fb486857c6..02310db6aea 100644 --- a/distributed/protocol/tests/test_pickle.py +++ b/distributed/protocol/tests/test_pickle.py @@ -5,7 +5,7 @@ import pytest -from distributed.profile import wait_profiler +from distributed import profile from distributed.protocol import deserialize, serialize from distributed.protocol.pickle import HIGHEST_PROTOCOL, dumps, loads @@ -181,7 +181,7 @@ def funcs(): assert func3(1) == func(1) del func, func2, func3 - wait_profiler() - assert wr() is None - assert wr2() is None - assert wr3() is None + with profile.lock: + assert wr() is None + assert wr2() is None + assert wr3() is None diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 84562cee66f..048661ac2f2 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2882,6 +2882,14 @@ def __init__( transition_counter_max=False, **kwargs, ): + if loop is not None: + warnings.warn( + "the loop kwarg to Scheduler is deprecated", + DeprecationWarning, + stacklevel=2, + ) + + self.loop = IOLoop.current() self._setup_logging(logger) # Attributes @@ -2961,7 +2969,6 @@ def __init__( ) # Communication state - self.loop = loop or IOLoop.current() self.client_comms = {} self.stream_comms = {} self._worker_coroutines = [] @@ -3347,7 +3354,7 @@ def del_scheduler_file(): setproctitle(f"dask-scheduler [{self.address}]") return self - async def close(self, fast=False, close_workers=False): + async def close(self): """Send cleanup signal to all coroutines then wait until finished See Also @@ -3358,8 +3365,14 @@ async def close(self, fast=False, close_workers=False): await self.finished() return + async def log_errors(func): + try: + await func() + except Exception: + logger.exception("Plugin call failed during scheduler.close") + await asyncio.gather( - *[plugin.before_close() for plugin in 
list(self.plugins.values())] + *[log_errors(plugin.before_close) for plugin in list(self.plugins.values())] ) self.status = Status.closing @@ -3368,23 +3381,13 @@ async def close(self, fast=False, close_workers=False): setproctitle("dask-scheduler [closing]") for preload in self.preloads: - await preload.teardown() - - if close_workers: - await self.broadcast(msg={"op": "close_gracefully"}, nanny=True) - for worker in self.workers: - # Report would require the worker to unregister with the - # currently closing scheduler. This is not necessary and might - # delay shutdown of the worker unnecessarily - self.worker_send(worker, {"op": "close", "report": False}) - for i in range(20): # wait a second for send signals to clear - if self.workers: - await asyncio.sleep(0.05) - else: - break + try: + await preload.teardown() + except Exception as e: + logger.exception(e) await asyncio.gather( - *[plugin.close() for plugin in list(self.plugins.values())] + *[log_errors(plugin.close) for plugin in list(self.plugins.values())] ) for pc in self.periodic_callbacks.values(): @@ -3399,15 +3402,18 @@ async def close(self, fast=False, close_workers=False): logger.info("Scheduler closing all comms") futures = [] - for w, comm in list(self.stream_comms.items()): + for _, comm in list(self.stream_comms.items()): + # FIXME use `self.remove_worker()` instead after https://github.com/dask/distributed/issues/6390 if not comm.closed(): - comm.send({"op": "close", "report": False}) + # This closes the Worker and ensures that if a Nanny is around, + # it is closed as well + comm.send({"op": "close"}) comm.send({"op": "close-stream"}) + # ^ TODO remove? `Worker.close` will close the stream anyway. with suppress(AttributeError): futures.append(comm.close()) - for future in futures: # TODO: do all at once - await future + await asyncio.gather(*futures) for comm in self.client_comms.values(): comm.abort() @@ -3431,8 +3437,7 @@ async def close_worker(self, worker: str, stimulus_id: str, safe: bool = False): """ logger.info("Closing worker %s", worker) self.log_event(worker, {"action": "close-worker"}) - # FIXME: This does not handle nannies - self.worker_send(worker, {"op": "close", "report": False}) + self.worker_send(worker, {"op": "close"}) # TODO redundant with `remove_worker` await self.remove_worker(address=worker, safe=safe, stimulus_id=stimulus_id) ########### @@ -4152,7 +4157,9 @@ def stimulus_retry(self, keys, client=None): return tuple(seen) @log_errors - async def remove_worker(self, address, stimulus_id, safe=False, close=True): + async def remove_worker( + self, address: str, *, stimulus_id: str, safe: bool = False, close: bool = True + ) -> Literal["OK", "already-removed"]: """ Remove worker from cluster @@ -4161,7 +4168,7 @@ async def remove_worker(self, address, stimulus_id, safe=False, close=True): state. 
""" if self.status == Status.closed: - return + return "already-removed" address = self.coerce_address(address) @@ -4183,7 +4190,7 @@ async def remove_worker(self, address, stimulus_id, safe=False, close=True): logger.info("Remove worker %s", ws) if close: with suppress(AttributeError, CommClosedError): - self.stream_comms[address].send({"op": "close", "report": False}) + self.stream_comms[address].send({"op": "close"}) self.remove_resources(address) @@ -4772,7 +4779,6 @@ def handle_worker_status_change( worker_msgs: dict = {} self._transitions(recs, client_msgs, worker_msgs, stimulus_id) self.send_all(client_msgs, worker_msgs) - else: self.running.discard(ws) @@ -4865,7 +4871,8 @@ async def register_scheduler_plugin(self, plugin, name=None, idempotent=None): "arbitrary bytestrings using pickle via the " "'distributed.scheduler.pickle' configuration setting." ) - plugin = loads(plugin) + if not isinstance(plugin, SchedulerPlugin): + plugin = loads(plugin) if name is None: name = _get_plugin_name(plugin) @@ -5101,12 +5108,7 @@ async def restart(self, client=None, timeout=30): ] resps = All( - [ - nanny.restart( - close=True, timeout=timeout * 0.8, executor_wait=False - ) - for nanny in nannies - ] + [nanny.restart(close=True, timeout=timeout * 0.8) for nanny in nannies] ) try: resps = await asyncio.wait_for(resps, timeout) @@ -5999,6 +6001,8 @@ async def retire_workers( prev_status = ws.status ws.status = Status.closing_gracefully self.running.discard(ws) + # FIXME: We should send a message to the nanny first; + # eventually workers won't be able to close their own nannies. self.stream_comms[ws.address].send( { "op": "worker-status-change", diff --git a/distributed/tests/test_actor.py b/distributed/tests/test_actor.py index bd065deb4bf..582180a8869 100644 --- a/distributed/tests/test_actor.py +++ b/distributed/tests/test_actor.py @@ -635,8 +635,7 @@ class UsesCounter: def do_inc(self, ac): return ac.increment().result() - with cluster(nworkers=1) as (cl, _): - client = Client(cl["address"]) + with cluster(nworkers=1) as (cl, _), Client(cl["address"]) as client: ac = client.submit(Counter, actor=True).result() ac2 = client.submit(UsesCounter, actor=True).result() @@ -652,8 +651,7 @@ def do_inc(self, ac): # cannot expire return ac.increment().result(timeout=0.001) - with cluster(nworkers=1) as (cl, _): - client = Client(cl["address"]) + with cluster(nworkers=1) as (cl, _), Client(cl["address"]) as client: ac = client.submit(Counter, actor=True).result() ac2 = client.submit(UsesCounter, actor=True).result() @@ -667,8 +665,7 @@ class UsesCounter: def do_inc(self, ac): return get_client().sync(ac.increment) - with cluster(nworkers=1) as (cl, _): - client = Client(cl["address"]) + with cluster(nworkers=1) as (cl, _), Client(cl["address"]) as client: ac = client.submit(Counter, actor=True).result() ac2 = client.submit(UsesCounter, actor=True).result() @@ -701,8 +698,7 @@ def method(self): def prop(self): raise MyException - with cluster(nworkers=2) as (cl, w): - client = Client(cl["address"]) + with cluster(nworkers=2) as (cl, w), Client(cl["address"]) as client: ac = client.submit(Broken, actor=True).result() acfut = ac.method() with pytest.raises(MyException): diff --git a/distributed/tests/test_asyncprocess.py b/distributed/tests/test_asyncprocess.py index 00300f07297..6cf54b987ac 100644 --- a/distributed/tests/test_asyncprocess.py +++ b/distributed/tests/test_asyncprocess.py @@ -193,7 +193,7 @@ async def test_terminate(): await proc.start() await proc.terminate() - await proc.join(timeout=30) 
+ await proc.join() assert not proc.is_alive() assert proc.exitcode in (-signal.SIGTERM, 255) @@ -312,6 +312,26 @@ async def test_terminate_after_stop(): await proc.start() await asyncio.sleep(0.1) await proc.terminate() + await proc.join() + + +def kill_target(ev): + signal.signal(signal.SIGTERM, signal.SIG_IGN) + ev.set() + sleep(300) + + +@pytest.mark.skipif(WINDOWS, reason="Needs SIGKILL") +@gen_test() +async def test_kill(): + ev = mp_context.Event() + proc = AsyncProcess(target=kill_target, args=(ev,)) + await proc.start() + ev.wait() + await proc.kill() + await proc.join() + assert not proc.is_alive() + assert proc.exitcode in (-signal.SIGKILL, 255) def _worker_process(worker_ready, child_pipe): diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index a102a0b1a35..67fac061031 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -302,7 +302,7 @@ async def get_data(self, comm, *args, **kwargs): s.set_restrictions({fut1.key: [a.address, b.address]}) # It is removed, i.e. get_data is guaranteed to fail and f1 is scheduled # to be recomputed on B - await s.remove_worker(a.address, "foo", close=False, safe=True) + await s.remove_worker(a.address, stimulus_id="foo", close=False, safe=True) while not b.tasks[fut1.key].state == "resumed": await asyncio.sleep(0.01) @@ -440,7 +440,7 @@ async def get_data(self, comm, *args, **kwargs): f3.key: {w2.address}, } ) - await s.remove_worker(w1.address, "stim-id") + await s.remove_worker(w1.address, stimulus_id="stim-id") await wait_for_state(f3.key, "resumed", w2) assert_story( diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 32a92d52716..b82f36e27f5 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -69,7 +69,6 @@ from distributed.compatibility import LINUX, WINDOWS from distributed.core import Server, Status from distributed.metrics import time -from distributed.profile import wait_profiler from distributed.scheduler import CollectTaskMetaDataPlugin, KilledWorker, Scheduler from distributed.sizeof import sizeof from distributed.utils import is_valid_xml, mp_context, sync, tmp_text @@ -678,8 +677,8 @@ def test_no_future_references(c): futures = c.map(inc, range(10)) ws.update(futures) del futures - wait_profiler() - assert not list(ws) + with profile.lock: + assert not list(ws) def test_get_sync_optimize_graph_passes_through(c): @@ -811,9 +810,9 @@ async def test_recompute_released_key(c, s, a, b): result1 = await x xkey = x.key del x - wait_profiler() - await asyncio.sleep(0) - assert c.refcount[xkey] == 0 + with profile.lock: + await asyncio.sleep(0) + assert c.refcount[xkey] == 0 # 1 second batching needs a second action to trigger while xkey in s.tasks and s.tasks[xkey].who_has or xkey in a.data or xkey in b.data: @@ -3483,10 +3482,9 @@ async def test_Client_clears_references_after_restart(c, s, a, b): key = x.key del x - wait_profiler() - await asyncio.sleep(0) - - assert key not in c.refcount + with profile.lock: + await asyncio.sleep(0) + assert key not in c.refcount @gen_cluster(Worker=Nanny, client=True) @@ -3655,7 +3653,7 @@ async def hard_stop(s): except CancelledError: break - await w.close(report=False) + await w.close() await c._close(fast=True) @@ -6134,7 +6132,7 @@ async def test_shutdown(): await c.shutdown() assert s.status == Status.closed - assert w.status == Status.closed + assert w.status in {Status.closed, Status.closing} @gen_test() diff --git 
a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 74e38f8671b..7422181d024 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -978,22 +978,22 @@ async def long_handler(comm): await asyncio.sleep(0.01) +@pytest.mark.parametrize("close_via_rpc", [True, False]) @gen_test() -async def test_close_fast_without_active_handlers(): - async def very_fast(comm): - return "done" +async def test_close_fast_without_active_handlers(close_via_rpc): - server = await Server({"do_stuff": very_fast}) + server = await Server({}) + server.handlers["terminate"] = server.close await server.listen(0) assert server._comms == {} - comm = await connect(server.address) - await comm.write({"op": "do_stuff"}) - while not server._comms: - await asyncio.sleep(0.05) - fut = server.close() - - await asyncio.wait_for(fut, 0.1) + if not close_via_rpc: + fut = server.close() + await asyncio.wait_for(fut, 0.5) + else: + async with rpc(server.address) as _rpc: + fut = _rpc.terminate(reply=False) + await asyncio.wait_for(fut, 0.5) @gen_test() @@ -1010,13 +1010,14 @@ async def long_handler(comm, delay=10): await comm.write({"op": "wait"}) while not server._comms: await asyncio.sleep(0.05) - fut = server.close() + task = asyncio.create_task(server.close()) + wait_for_close = asyncio.Event() + task.add_done_callback(lambda _: wait_for_close.set()) # since the handler is running for a while, the close will not immediately # go through. We'll give the comm about a second to close itself with pytest.raises(asyncio.TimeoutError): - await asyncio.wait_for(fut, 0.5) - await comm.close() - await server.close() + await asyncio.wait_for(wait_for_close.wait(), 0.5) + await task def test_expects_comm(): diff --git a/distributed/tests/test_diskutils.py b/distributed/tests/test_diskutils.py index c9f1a91a63e..9127651cfd6 100644 --- a/distributed/tests/test_diskutils.py +++ b/distributed/tests/test_diskutils.py @@ -12,10 +12,10 @@ import dask +from distributed import profile from distributed.compatibility import WINDOWS from distributed.diskutils import WorkSpace from distributed.metrics import time -from distributed.profile import wait_profiler from distributed.utils import mp_context from distributed.utils_test import captured_logger @@ -53,8 +53,8 @@ def test_workdir_simple(tmpdir): a.release() assert_contents(["bb", "bb.dirlock"]) del b - wait_profiler() - gc.collect() + with profile.lock: + gc.collect() assert_contents([]) # Generated temporary name with a prefix @@ -89,12 +89,12 @@ def test_two_workspaces_in_same_directory(tmpdir): del ws del b - wait_profiler() - gc.collect() + with profile.lock: + gc.collect() assert_contents(["aa", "aa.dirlock"], trials=5) del a - wait_profiler() - gc.collect() + with profile.lock: + gc.collect() assert_contents([], trials=5) @@ -188,8 +188,8 @@ def test_locking_disabled(tmpdir): a.release() assert_contents(["bb"]) del b - wait_profiler() - gc.collect() + with profile.lock: + gc.collect() assert_contents([]) lock_file.assert_not_called() diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 371eb8ae54f..dfdfa2c0031 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -10,11 +10,10 @@ from dask import delayed -from distributed import Client, Nanny, wait +from distributed import Client, Nanny, profile, wait from distributed.comm import CommClosedError from distributed.compatibility import MACOS from distributed.metrics import time -from distributed.profile 
import wait_profiler from distributed.utils import CancelledError, sync from distributed.utils_test import ( captured_logger, @@ -262,7 +261,10 @@ async def test_forgotten_futures_dont_clean_up_new_futures(c, s, a, b): await c.restart() y = c.submit(inc, 1) del x - wait_profiler() + + # Ensure that the profiler has stopped and released all references to x so that it can be garbage-collected + with profile.lock: + pass await asyncio.sleep(0.1) await y diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index bef0cc04010..c3588d625b9 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -19,12 +19,11 @@ import dask from dask.utils import tmpfile -from distributed import Nanny, Scheduler, Worker, rpc, wait, worker +from distributed import Nanny, Scheduler, Worker, profile, rpc, wait, worker from distributed.compatibility import LINUX, WINDOWS from distributed.core import CommClosedError, Status from distributed.diagnostics import SchedulerPlugin from distributed.metrics import time -from distributed.profile import wait_profiler from distributed.protocol.pickle import dumps from distributed.utils import TimeoutError, parse_ports from distributed.utils_test import ( @@ -170,8 +169,8 @@ async def test_num_fds(s): # Warm up async with Nanny(s.address): pass - wait_profiler() - gc.collect() + with profile.lock: + gc.collect() before = proc.num_fds() @@ -404,7 +403,7 @@ def remove_worker(self, **kwargs): @gen_cluster(client=True, nthreads=[]) -async def test_nanny_closes_cleanly_2(c, s): +async def test_nanny_closes_cleanly_if_worker_is_terminated(c, s): async with Nanny(s.address) as n: async with c.rpc(n.worker_address) as w: IOLoop.current().add_callback(w.terminate) @@ -460,7 +459,7 @@ def raise_err(): @pytest.mark.parametrize("protocol", ["tcp", "ucx"]) @gen_test() -async def test_nanny_closed_by_keyboard_interrupt(protocol): +async def test_nanny_closed_by_keyboard_interrupt(ucx_loop, protocol): if protocol == "ucx": # Skip if UCX isn't available pytest.importorskip("ucp") @@ -568,3 +567,19 @@ async def test_restart_memory(c, s, n): while not s.workers: await asyncio.sleep(0.1) + + +@gen_cluster(Worker=Nanny, nthreads=[("", 1)]) +async def test_scheduler_crash_doesnt_restart(s, a): + # Simulate a scheduler crash by disconnecting it first + # (`s.close()` would tell workers to cleanly shut down) + bcomm = next(iter(s.stream_comms.values())) + bcomm.abort() + await s.close() + + while a.status != Status.closing_gracefully: + await asyncio.sleep(0.01) + + await a.finished() + assert a.status == Status.closed + assert a.process is None diff --git a/distributed/tests/test_preload.py b/distributed/tests/test_preload.py index 0bb73dd0d15..04bec1cb951 100644 --- a/distributed/tests/test_preload.py +++ b/distributed/tests/test_preload.py @@ -1,3 +1,4 @@ +import logging import os import re import shutil @@ -281,6 +282,23 @@ def dask_setup(client, value): assert c.foo == value +@gen_test() +async def test_teardown_failure_doesnt_crash_scheduler(): + text = """ +def dask_teardown(worker): + raise Exception(123) +""" + + with captured_logger(logging.getLogger("distributed.scheduler")) as s_logger: + with captured_logger(logging.getLogger("distributed.worker")) as w_logger: + async with Scheduler(dashboard_address=":0", preload=text) as s: + async with Worker(s.address, preload=[text]) as w: + pass + + assert "123" in s_logger.getvalue() + assert "123" in w_logger.getvalue() + + @gen_cluster(nthreads=[]) async def 
test_client_preload_config_click(s): text = dedent( diff --git a/distributed/tests/test_profile.py b/distributed/tests/test_profile.py index 92fb6c1cfee..1d417cb19c5 100644 --- a/distributed/tests/test_profile.py +++ b/distributed/tests/test_profile.py @@ -18,6 +18,7 @@ info_frame, ll_get_stack, llprocess, + lock, merge, plot_data, process, @@ -184,27 +185,68 @@ def test_identifier(): def test_watch(): + stop_called = threading.Event() + watch_thread = None start = time() def stop(): + if not stop_called.is_set(): # Run setup code + nonlocal watch_thread + nonlocal start + watch_thread = threading.current_thread() + start = time() + stop_called.set() return time() > start + 0.500 + log = watch(interval="10ms", cycle="50ms", stop=stop) + + stop_called.wait(2) + sleep(0.5) + assert 1 < len(log) < 10 + watch_thread.join(2) + + +def test_watch_requires_lock_to_run(): + start = time() + + stop_profiling_called = threading.Event() + profiling_thread = None + + def stop_profiling(): + if not stop_profiling_called.is_set(): # Run setup code + nonlocal profiling_thread + nonlocal start + profiling_thread = threading.current_thread() + start = time() + stop_profiling_called.set() + return time() > start + 0.500 + + release_lock = threading.Event() + + def block_lock(): + with lock: + release_lock.wait() + start_threads = threading.active_count() - log = watch(interval="10ms", cycle="50ms", stop=stop) + # Block the lock over the entire duration of watch + blocking_thread = threading.Thread(target=block_lock, name="Block Lock") + blocking_thread.daemon = True + blocking_thread.start() + + log = watch(interval="10ms", cycle="50ms", stop=stop_profiling) start = time() # wait until thread starts up - while threading.active_count() <= start_threads: + while threading.active_count() < start_threads + 2: assert time() < start + 2 sleep(0.01) sleep(0.5) - assert 1 < len(log) < 10 + assert len(log) == 0 + release_lock.set() - start = time() - while threading.active_count() > start_threads: - assert time() < start + 2 - sleep(0.01) + profiling_thread.join(2) + blocking_thread.join(2) @dataclasses.dataclass(frozen=True) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 2b455836bc5..e304a3959ec 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -15,6 +15,7 @@ import psutil import pytest from tlz import concat, first, merge, valmap +from tornado.ioloop import IOLoop import dask from dask import delayed @@ -779,8 +780,14 @@ async def test_update_graph_culls(s, a, b): def test_io_loop(loop): - s = Scheduler(loop=loop, dashboard_address=":0", validate=True) - assert s.io_loop is loop + async def main(): + with pytest.warns( + DeprecationWarning, match=r"the loop kwarg to Scheduler is deprecated" + ): + s = Scheduler(loop=loop, dashboard_address=":0", validate=True) + assert s.io_loop is IOLoop.current() + + asyncio.run(main()) @gen_cluster(client=True) @@ -1758,10 +1765,12 @@ async def test_result_type(c, s, a, b): @gen_cluster() -async def test_close_workers(s, a, b): - await s.close(close_workers=True) - assert a.status == Status.closed - assert b.status == Status.closed +async def test_close_workers(s, *workers): + await s.close() + + for w in workers: + if not w.status == Status.closed: + await asyncio.sleep(0.1) @pytest.mark.skipif(not LINUX, reason="Need 127.0.0.2 to mean localhost") @@ -2591,7 +2600,7 @@ async def test_memory_is_none(c, s): @gen_cluster() async def test_close_scheduler__close_workers_Worker(s, a, b): with 
captured_logger("distributed.comm", level=logging.DEBUG) as log: - await s.close(close_workers=True) + await s.close() while not a.status == Status.closed: await asyncio.sleep(0.05) log = log.getvalue() @@ -2601,7 +2610,7 @@ async def test_close_scheduler__close_workers_Worker(s, a, b): @gen_cluster(Worker=Nanny) async def test_close_scheduler__close_workers_Nanny(s, a, b): with captured_logger("distributed.comm", level=logging.DEBUG) as log: - await s.close(close_workers=True) + await s.close() while not a.status == Status.closed: await asyncio.sleep(0.05) log = log.getvalue() @@ -2729,6 +2738,14 @@ async def test_rebalance_raises_missing_data3(c, s, a, b, explicit): futures = await c.scatter(range(100), workers=[a.address]) if explicit: + pytest.xfail( + reason="""Freeing keys and gathering data is using different + channels (stream vs explicit RPC). Therefore, the + partial-fail is very timing sensitive and subject to a race + condition. This test assumes that the data is freed before + the rebalance get_data requests come in but merely deleting + the futures is not sufficient to guarantee this""" + ) keys = [f.key for f in futures] del futures out = await s.rebalance(keys=keys) diff --git a/distributed/tests/test_spill.py b/distributed/tests/test_spill.py index 4bd08875303..7c14013c37e 100644 --- a/distributed/tests/test_spill.py +++ b/distributed/tests/test_spill.py @@ -8,8 +8,8 @@ from dask.sizeof import sizeof +from distributed import profile from distributed.compatibility import WINDOWS -from distributed.profile import wait_profiler from distributed.protocol import serialize_bytelist from distributed.spill import SpillBuffer, has_zict_210, has_zict_220 from distributed.utils_test import captured_logger @@ -338,7 +338,10 @@ def test_weakref_cache(tmpdir, cls, expect_cached, size): # the same id as a deleted one id_x = x.id del x - wait_profiler() + + # Ensure that the profiler has stopped and released all references to x so that it can be garbage-collected + with profile.lock: + pass if size < 100: buf["y"] diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 9fc2420d6ba..4b68f927596 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -12,12 +12,11 @@ import dask -from distributed import Event, Lock, Nanny, Worker, wait, worker_client +from distributed import Event, Lock, Nanny, Worker, profile, wait, worker_client from distributed.compatibility import LINUX from distributed.config import config from distributed.core import Status from distributed.metrics import time -from distributed.profile import wait_profiler from distributed.scheduler import key_split from distributed.system import MEMORY_LIMIT from distributed.utils_test import ( @@ -948,8 +947,8 @@ class Foo: assert not s.tasks - wait_profiler() - assert not list(ws) + with profile.lock: + assert not list(ws) @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index 961bd6a9164..bc1dea2e81c 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -259,8 +259,12 @@ def test_seek_delimiter_endline(): memoryview(bytearray(b"1")), array("B", b"1"), array("I", range(5)), + memoryview(b"123456")[1:-1], memoryview(b"123456")[::2], + memoryview(array("I", range(5)))[1:-1], + memoryview(array("I", range(5)))[::2], memoryview(b"123456").cast("B", (2, 3)), + memoryview(b"0123456789").cast("B", (5, 2))[1:-1], memoryview(b"0123456789").cast("B", (5, 2))[::2], 
], ) @@ -273,7 +277,6 @@ def test_ensure_memoryview(data): assert result.format == "B" assert result == bytes(data_mv) if data_mv.nbytes and data_mv.contiguous: - assert id(result.obj) == id(data_mv.obj) assert result.readonly == data_mv.readonly if isinstance(data, memoryview): if data.ndim == 1 and data.format == "B": diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index c99e05532b8..3116fb3a8c8 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -16,13 +16,14 @@ from distributed import Client, Nanny, Scheduler, Worker, config, default_client from distributed.compatibility import WINDOWS -from distributed.core import Server, rpc +from distributed.core import Server, Status, rpc from distributed.metrics import time from distributed.utils import mp_context from distributed.utils_test import ( _LockedCommPool, _UnhashableCallable, assert_story, + captured_logger, check_process_leak, cluster, dump_cluster_state, @@ -33,7 +34,7 @@ raises_with_cause, tls_only_security, ) -from distributed.worker import InvalidTransition +from distributed.worker import InvalidTransition, fail_hard def test_bare_cluster(loop): @@ -731,15 +732,43 @@ def test_raises_with_cause(): raise RuntimeError("exception") from ValueError("cause") -def test_worker_fail_hard(capsys): - @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) - async def test_fail_hard(c, s, a): - with pytest.raises(Exception): - await a.gather_dep( - worker="abcd", to_gather=["x"], total_nbytes=0, stimulus_id="foo" - ) +@pytest.mark.parametrize("sync", [True, False]) +def test_fail_hard(sync): + """@fail_hard is a last resort when error handling for everything that we foresaw + could possibly go wrong failed. + Instead of trying to force a crash here, we'll write custom methods which do crash. 
+ """ - with pytest.raises(Exception) as info: - test_fail_hard() + class CustomError(Exception): + pass + + class FailWorker(Worker): + @fail_hard + def fail_sync(self): + raise CustomError() + + @fail_hard + async def fail_async(self): + raise CustomError() + + test_done = False + + @gen_cluster(nthreads=[]) + async def test(s): + nonlocal test_done + with captured_logger("distributed.worker") as logger: + async with FailWorker(s.address) as a: + with pytest.raises(CustomError): + if sync: + a.fail_sync() + else: + await a.fail_async() + + while a.status != Status.closed: + await asyncio.sleep(0.01) + + test_done = True - assert "abcd" in str(info.value) + with pytest.raises(CustomError): + test() + assert test_done diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index eebce39d57e..8032d81a071 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -34,6 +34,7 @@ default_client, get_client, get_worker, + profile, wait, ) from distributed.comm.registry import backends @@ -42,7 +43,6 @@ from distributed.diagnostics import nvml from distributed.diagnostics.plugin import PipInstall from distributed.metrics import time -from distributed.profile import wait_profiler from distributed.protocol import pickle from distributed.scheduler import Scheduler from distributed.utils_test import ( @@ -192,7 +192,7 @@ def g(): assert result == 123 await c.close() - await s.close(close_workers=True) + await s.close() assert not os.path.exists(os.path.join(a.local_directory, "foobar.py")) @@ -1384,7 +1384,7 @@ async def test_interface_async(Worker): @pytest.mark.gpu @pytest.mark.parametrize("Worker", [Worker, Nanny]) @gen_test() -async def test_protocol_from_scheduler_address(Worker): +async def test_protocol_from_scheduler_address(ucx_loop, Worker): pytest.importorskip("ucp") async with Scheduler(protocol="ucx", dashboard_address=":0") as s: @@ -1741,7 +1741,7 @@ async def close(self): ) as wlogger, captured_logger( "distributed.scheduler", level=logging.WARNING ) as slogger: - await s.remove_worker(a.address, "foo") + await s.remove_worker(a.address, stimulus_id="foo") assert not s.workers # Wait until the close signal reaches the worker and it starts shutting down. 
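[Illustrative aside, not part of the patch] The new ``test_fail_hard`` test above pins down the contract of the ``@fail_hard`` decorator from ``distributed.worker`` after this change: when a decorated worker method raises, the error is logged, the worker is force-closed, and the exception is re-raised to the caller. Below is a minimal sketch of that pattern under assumed names; ``fail_hard_sketch`` is hypothetical, and its call to ``Worker.close`` stands in for the internal force-close helper used by the real decorator.

import functools
import logging

logger = logging.getLogger("distributed.worker")


def fail_hard_sketch(method):
    """Log, force-close the worker, and re-raise if ``method`` raises (sync case)."""

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        try:
            return method(self, *args, **kwargs)
        except Exception as e:
            logger.exception(e)
            # Schedule the shutdown on the worker's event loop; the exception
            # still propagates to the caller, matching the assertions in
            # test_fail_hard above.
            self.loop.add_callback(self.close, executor_wait=False)
            raise

    return wrapper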
@@ -1851,8 +1851,8 @@ class C: del f while "f" in a.data: await asyncio.sleep(0.01) - wait_profiler() - assert ref() is None + with profile.lock: + assert ref() is None story = a.stimulus_story("f", "f2") assert {ev.key for ev in story} == {"f", "f2"} @@ -2569,6 +2569,24 @@ def __call__(self, *args, **kwargs): threadpool.shutdown() +@gen_cluster(client=True) +async def test_run_spec_deserialize_fail(c, s, a, b): + class F: + def __call__(self): + pass + + def __reduce__(self): + return lambda: 1 / 0, () + + with captured_logger("distributed.worker") as logger: + fut = c.submit(F()) + assert isinstance(await fut.exception(), ZeroDivisionError) + + logvalue = logger.getvalue() + assert "Could not deserialize task" in logvalue + assert "return lambda: 1 / 0, ()" in logvalue + + @gen_cluster(client=True) async def test_gather_dep_exception_one_task(c, s, a, b): """Ensure an exception in a single task does not tear down an entire batch of gather_dep @@ -2962,7 +2980,7 @@ async def test_missing_released_zombie_tasks(c, s, a, b): while key not in b.tasks or b.tasks[key].state != "fetch": await asyncio.sleep(0.01) - await a.close(report=False) + await a.close() del f1, f2 diff --git a/distributed/tests/test_worker_memory.py b/distributed/tests/test_worker_memory.py index 1641927352d..84e71763d09 100644 --- a/distributed/tests/test_worker_memory.py +++ b/distributed/tests/test_worker_memory.py @@ -5,18 +5,18 @@ import logging import os import signal -import sys import threading from collections import Counter, UserDict from time import sleep +import psutil import pytest import dask.config import distributed.system from distributed import Client, Event, KilledWorker, Nanny, Scheduler, Worker, wait -from distributed.compatibility import MACOS +from distributed.compatibility import MACOS, WINDOWS from distributed.core import Status from distributed.metrics import monotonic from distributed.spill import has_zict_210 @@ -684,7 +684,7 @@ async def test_manual_evict_proto(c, s, a): await asyncio.sleep(0.01) -async def leak_until_restart(c: Client, s: Scheduler, a: Nanny) -> None: +async def leak_until_restart(c: Client, s: Scheduler) -> None: s.allowed_failures = 0 def leak(): @@ -693,32 +693,25 @@ def leak(): L.append(b"0" * 5_000_000) sleep(0.01) - assert a.process - assert a.process.process - pid = a.process.pid - addr = a.worker_address - with captured_logger(logging.getLogger("distributed.worker_memory")) as logger: - future = c.submit(leak, key="leak") - while ( - not a.process - or not a.process.process - or a.process.pid == pid - or a.worker_address == addr - ): - await asyncio.sleep(0.01) + (addr,) = s.workers + pid = (await c.run(os.getpid))[addr] - # Test that the restarting message happened only once; - # see test_slow_terminate below. - assert logger.getvalue() == ( - f"Worker {addr} (pid={pid}) exceeded 95% memory budget. 
Restarting...\n" - ) + future = c.submit(leak, key="leak") + + # Wait until the worker is restarted + while len(s.workers) != 1 or set(s.workers) == {addr}: + await asyncio.sleep(0.01) + + # Test that the process has been properly waited for and not just left there + with pytest.raises(psutil.NoSuchProcess): + psutil.Process(pid) with pytest.raises(KilledWorker): await future assert s.tasks["leak"].suspicious == 1 - assert await c.run(lambda dask_worker: "leak" in dask_worker.tasks) == { - a.worker_address: False - } + assert not any( + (await c.run(lambda dask_worker: "leak" in dask_worker.tasks)).values() + ) future.release() while "leak" in s.tasks: await asyncio.sleep(0.01) @@ -733,10 +726,17 @@ def leak(): config={"distributed.worker.memory.monitor-interval": "10ms"}, ) async def test_nanny_terminate(c, s, a): - await leak_until_restart(c, s, a) + await leak_until_restart(c, s) @pytest.mark.slow +@pytest.mark.parametrize( + "ignore_sigterm", + [ + False, + pytest.param(True, marks=pytest.mark.skipif(WINDOWS, reason="Needs SIGKILL")), + ], +) @gen_cluster( nthreads=[("", 1)], client=True, @@ -744,50 +744,31 @@ async def test_nanny_terminate(c, s, a): worker_kwargs={"memory_limit": "400 MiB"}, config={"distributed.worker.memory.monitor-interval": "10ms"}, ) -async def test_disk_cleanup_on_terminate(c, s, a): - """Test that the spilled data on disk is cleaned up when the nanny kills the worker""" +async def test_disk_cleanup_on_terminate(c, s, a, ignore_sigterm): + """Test that the spilled data on disk is cleaned up when the nanny kills the worker. + + Unlike in a regular worker shutdown, where the worker deletes its own spill + directory, the cleanup in case of termination from the monitor is performed by the + nanny. + + The worker may be slow to accept SIGTERM, for whatever reason. + At the next iteration of the memory manager, if the process is still alive, the + nanny sends SIGKILL. + """ + if ignore_sigterm: + await c.run(signal.signal, signal.SIGTERM, signal.SIG_IGN) + fut = c.submit(inc, 1, key="myspill") await wait(fut) await c.run(lambda dask_worker: dask_worker.data.evict()) - glob_out = await c.run( lambda dask_worker: glob.glob(dask_worker.local_directory + "/**/myspill") ) - spill_file = glob_out[a.worker_address][0] - assert os.path.exists(spill_file) - await leak_until_restart(c, s, a) - assert not os.path.exists(spill_file) - - -@pytest.mark.slow -@gen_cluster( - client=True, - Worker=Nanny, - nthreads=[("", 1)], - worker_kwargs={"memory_limit": "400 MiB"}, - config={"distributed.worker.memory.monitor-interval": "10ms"}, -) -async def test_slow_terminate(c, s, a): - """A worker is slow to accept SIGTERM, e.g. because the - distributed.diskutils.WorkDir teardown is deleting tens of GB worth of spilled data. - """ + spill_fname = next(iter(glob_out.values()))[0] + assert os.path.exists(spill_fname) - def install_slow_sigterm_handler(): - def cb(signo, frame): - # If something sends SIGTERM while the previous SIGTERM handler is running, - # you will eventually get RecursionError. 
- print(f"Received signal {signo}") - sleep(0.2) # Longer than monitor-interval - print("Leaving handler") - sys.exit(0) - - signal.signal(signal.SIGTERM, cb) - - await c.run(install_slow_sigterm_handler) - # Test that SIGTERM is only sent once - await leak_until_restart(c, s, a) - # Test that SIGTERM can be sent again after the worker restarts - await leak_until_restart(c, s, a) + await leak_until_restart(c, s) + assert not os.path.exists(spill_fname) @gen_cluster( diff --git a/distributed/utils.py b/distributed/utils.py index e24e45b0c86..ad6745a09fd 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1029,9 +1029,8 @@ def ensure_memoryview(obj): elif mv.ndim != 1 or mv.format != "B": # Perform zero-copy reshape & cast # Use `PickleBuffer.raw()` as `memoryview.cast()` fails with F-order - # Pass `mv.obj` so the created `memoryview` has that as its `obj` # xref: https://github.com/python/cpython/issues/91484 - return PickleBuffer(mv.obj).raw() + return PickleBuffer(mv).raw() else: # Return `memoryview` as it already meets requirements return mv diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 8cce6d44494..edd74bef17e 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -19,18 +19,18 @@ import sys import tempfile import threading +import warnings import weakref from collections import defaultdict from collections.abc import Callable from contextlib import contextmanager, nullcontext, suppress from itertools import count from time import sleep -from typing import Any, Generator, Literal +from typing import IO, Any, Generator, Iterator, Literal import pytest import yaml from tlz import assoc, memoize, merge -from tornado import gen from tornado.ioloop import IOLoop import dask @@ -132,11 +132,11 @@ def invalid_python_script(tmpdir_factory): async def cleanup_global_workers(): for worker in Worker._instances: - await worker.close(report=False, executor_wait=False) + await worker.close(executor_wait=False) @pytest.fixture -def loop(): +def loop(cleanup): with check_instances(): with pristine_loop() as loop: # Monkey-patch IOLoop.start to wait for loop stop @@ -169,17 +169,35 @@ def start(): @pytest.fixture -def loop_in_thread(): - with pristine_loop() as loop: - thread = threading.Thread(target=loop.start, name="test IOLoop") - thread.daemon = True - thread.start() - loop_started = threading.Event() - loop.add_callback(loop_started.set) - loop_started.wait() - yield loop - loop.add_callback(loop.stop) - thread.join(timeout=5) +def loop_in_thread(cleanup): + loop_started = concurrent.futures.Future() + with concurrent.futures.ThreadPoolExecutor( + 1, thread_name_prefix="test IOLoop" + ) as tpe: + + async def run(): + io_loop = IOLoop.current() + stop_event = asyncio.Event() + loop_started.set_result((io_loop, stop_event)) + await stop_event.wait() + + # run asyncio.run in a thread and collect exceptions from *either* + # the loop failing to start, or failing to close + ran = tpe.submit(_run_and_close_tornado, run) + for f in concurrent.futures.as_completed((loop_started, ran)): + if f is loop_started: + io_loop, stop_event = loop_started.result() + try: + yield io_loop + finally: + io_loop.add_callback(stop_event.set) + + elif f is ran: + # if this is the first iteration the loop failed to start + # if it's the second iteration the loop has finished or + # the loop failed to close and we need to raise the exception + ran.result() + return @pytest.fixture @@ -442,29 +460,37 @@ async def background_read(): return msg +def 
_run_and_close_tornado(async_fn, /, *args, **kwargs): + tornado_loop = None + + async def inner_fn(): + nonlocal tornado_loop + tornado_loop = IOLoop.current() + return await async_fn(*args, **kwargs) + + try: + return asyncio.run(inner_fn()) + finally: + tornado_loop.close(all_fds=True) + + def run_scheduler(q, nputs, config, port=0, **kwargs): with dask.config.set(config): - # On Python 2.7 and Unix, fork() is used to spawn child processes, - # so avoid inheriting the parent's IO loop. - with pristine_loop() as loop: - - async def _(): - try: - scheduler = await Scheduler( - validate=True, host="127.0.0.1", port=port, **kwargs - ) - except Exception as exc: - for i in range(nputs): - q.put(exc) - else: - for i in range(nputs): - q.put(scheduler.address) - await scheduler.finished() + async def _(): try: - loop.run_sync(_) - finally: - loop.close(all_fds=True) + scheduler = await Scheduler( + validate=True, host="127.0.0.1", port=port, **kwargs + ) + except Exception as exc: + for i in range(nputs): + q.put(exc) + else: + for i in range(nputs): + q.put(scheduler.address) + await scheduler.finished() + + _run_and_close_tornado(_) def run_worker(q, scheduler_q, config, **kwargs): @@ -473,37 +499,12 @@ def run_worker(q, scheduler_q, config, **kwargs): reset_logger_locks() with log_errors(): - with pristine_loop() as loop: - scheduler_addr = scheduler_q.get() - - async def _(): - pid = os.getpid() - try: - worker = await Worker(scheduler_addr, validate=True, **kwargs) - except Exception as exc: - q.put((pid, exc)) - else: - q.put((pid, worker.address)) - await worker.finished() - - # Scheduler might've failed - if isinstance(scheduler_addr, str): - try: - loop.run_sync(_) - finally: - loop.close(all_fds=True) - - -@log_errors -def run_nanny(q, scheduler_q, config, **kwargs): - with dask.config.set(config): - with pristine_loop() as loop: scheduler_addr = scheduler_q.get() async def _(): pid = os.getpid() try: - worker = await Nanny(scheduler_addr, validate=True, **kwargs) + worker = await Worker(scheduler_addr, validate=True, **kwargs) except Exception as exc: q.put((pid, exc)) else: @@ -512,14 +513,36 @@ async def _(): # Scheduler might've failed if isinstance(scheduler_addr, str): - try: - loop.run_sync(_) - finally: - loop.close(all_fds=True) + _run_and_close_tornado(_) + + +@log_errors +def run_nanny(q, scheduler_q, config, **kwargs): + with dask.config.set(config): + scheduler_addr = scheduler_q.get() + + async def _(): + pid = os.getpid() + try: + worker = await Nanny(scheduler_addr, validate=True, **kwargs) + except Exception as exc: + q.put((pid, exc)) + else: + q.put((pid, worker.address)) + await worker.finished() + + # Scheduler might've failed + if isinstance(scheduler_addr, str): + _run_and_close_tornado(_) @contextmanager def check_active_rpc(loop, active_rpc_timeout=1): + warnings.warn( + "check_active_rpc is deprecated - use gen_test()", + DeprecationWarning, + stacklevel=2, + ) active_before = set(rpc.active) yield # Some streams can take a bit of time to notice their peer @@ -544,6 +567,29 @@ async def wait(): loop.run_sync(wait) +@contextlib.asynccontextmanager +async def _acheck_active_rpc(active_rpc_timeout=1): + active_before = set(rpc.active) + yield + # Some streams can take a bit of time to notice their peer + # has closed, and keep a coroutine (*) waiting for a CommClosedError + # before calling close_rpc() after a CommClosedError. + # This would happen especially if a non-localhost address is used, + # as Nanny does. 
+ # (*) (example: gather_from_workers()) + + def fail(): + pytest.fail( + "some RPCs left active by test: %s" % (set(rpc.active) - active_before) + ) + + await async_wait_for( + lambda: len(set(rpc.active) - active_before) == 0, + timeout=active_rpc_timeout, + fail_func=fail, + ) + + @pytest.fixture def cluster_fixture(loop): with cluster() as (scheduler, workers): @@ -656,7 +702,7 @@ def cluster( ws = weakref.WeakSet() enable_proctitle_on_children() - with clean(timeout=active_rpc_timeout, threads=False) as loop: + with check_process_leak(check=True), check_instances(), _reconfigure(): if nanny: _run_worker = run_nanny else: @@ -736,7 +782,7 @@ async def wait_for_workers(): if time() - start > 5: raise Exception("Timeout on cluster creation") - loop.run_sync(wait_for_workers) + _run_and_close_tornado(wait_for_workers) # avoid sending processes down to function yield {"address": saddr}, [ @@ -744,26 +790,25 @@ async def wait_for_workers(): for w in workers_by_pid.values() ] finally: - logger.debug("Closing out test cluster") - alive_workers = [ - w["address"] - for w in workers_by_pid.values() - if w["proc"].is_alive() - ] - loop.run_sync( - lambda: disconnect_all( + + async def close(): + logger.debug("Closing out test cluster") + alive_workers = [ + w["address"] + for w in workers_by_pid.values() + if w["proc"].is_alive() + ] + await disconnect_all( alive_workers, timeout=disconnect_timeout, rpc_kwargs=rpc_kwargs, ) - ) - if scheduler.is_alive(): - loop.run_sync( - lambda: disconnect( + if scheduler.is_alive(): + await disconnect( saddr, timeout=disconnect_timeout, rpc_kwargs=rpc_kwargs ) - ) + _run_and_close_tornado(close) try: client = default_client() except ValueError: @@ -792,7 +837,10 @@ async def disconnect_all(addresses, timeout=3, rpc_kwargs=None): await asyncio.gather(*(disconnect(addr, timeout, rpc_kwargs) for addr in addresses)) -def gen_test(timeout: float = _TEST_TIMEOUT) -> Callable[[Callable], Callable]: +def gen_test( + timeout: float = _TEST_TIMEOUT, + clean_kwargs: dict[str, Any] = {}, +) -> Callable[[Callable], Callable]: """Coroutine test @pytest.mark.parametrize("param", [1, 2, 3]) @@ -812,16 +860,20 @@ async def test_foo(): if is_debugging(): timeout = 3600 + async def async_fn_outer(async_fn, /, *args, **kwargs): + async with _acheck_active_rpc(): + return await asyncio.wait_for( + asyncio.create_task(async_fn(*args, **kwargs)), timeout + ) + def _(func): + @functools.wraps(func) + @clean(**clean_kwargs) def test_func(*args, **kwargs): - with clean() as loop: - injected_func = functools.partial(func, *args, **kwargs) - if iscoroutinefunction(func): - cor = injected_func - else: - cor = gen.coroutine(injected_func) + if not iscoroutinefunction(func): + raise RuntimeError("gen_test only works for coroutine functions.") - loop.run_sync(cor, timeout=timeout) + return _run_and_close_tornado(async_fn_outer, func, *args, **kwargs) # Patch the signature so pytest can inject fixtures test_func.__signature__ = inspect.signature(func) @@ -833,7 +885,7 @@ def test_func(*args, **kwargs): async def start_cluster( nthreads: list[tuple[str, int] | tuple[str, int, dict]], scheduler_addr: str, - loop: IOLoop, + loop: IOLoop | None = None, security: Security | dict[str, Any] | None = None, Worker: type[ServerNode] = Worker, scheduler_kwargs: dict[str, Any] = {}, @@ -877,7 +929,7 @@ async def start_cluster( await asyncio.sleep(0.01) if time() > start + 30: await asyncio.gather(*(w.close(timeout=1) for w in workers)) - await s.close(fast=True) + await s.close() 
check_invalid_worker_transitions(s) check_invalid_task_states(s) check_worker_fail_hard(s) @@ -931,7 +983,7 @@ async def end_cluster(s, workers): async def end_worker(w): with suppress(asyncio.TimeoutError, CommClosedError, EnvironmentError): - await w.close(report=False) + await w.close() await asyncio.gather(*(end_worker(w) for w in workers)) await s.close() # wait until scheduler stops completely @@ -1010,161 +1062,156 @@ def _(func): raise RuntimeError("gen_cluster only works for coroutine functions.") @functools.wraps(func) + @clean(**clean_kwargs) def test_func(*outer_args, **kwargs): - result = None - with clean(timeout=active_rpc_timeout, **clean_kwargs) as loop: - - async def coro(): - with tempfile.TemporaryDirectory() as tmpdir: - config2 = merge({"temporary-directory": tmpdir}, config) - with dask.config.set(config2): - workers = [] - s = False - - for _ in range(60): - try: - s, ws = await start_cluster( - nthreads, - scheduler, - loop, - security=security, - Worker=Worker, - scheduler_kwargs=scheduler_kwargs, - worker_kwargs=worker_kwargs, - ) - except Exception as e: - logger.error( - "Failed to start gen_cluster: " - f"{e.__class__.__name__}: {e}; retrying", - exc_info=True, - ) - await asyncio.sleep(1) - else: - workers[:] = ws - args = [s] + workers - break - if s is False: - raise Exception("Could not start cluster") - if client: - c = await Client( - s.address, - loop=loop, + async def async_fn(): + result = None + with tempfile.TemporaryDirectory() as tmpdir: + config2 = merge({"temporary-directory": tmpdir}, config) + with dask.config.set(config2): + workers = [] + s = False + + for _ in range(60): + try: + s, ws = await start_cluster( + nthreads, + scheduler, security=security, - asynchronous=True, - **client_kwargs, + Worker=Worker, + scheduler_kwargs=scheduler_kwargs, + worker_kwargs=worker_kwargs, ) - args = [c] + args - - try: - coro = func(*args, *outer_args, **kwargs) - task = asyncio.create_task(coro) - coro2 = asyncio.wait_for(asyncio.shield(task), timeout) - result = await coro2 - validate_state(s, *workers) - - except asyncio.TimeoutError: - assert task - buffer = io.StringIO() - # This stack indicates where the coro/test is suspended - task.print_stack(file=buffer) - - if cluster_dump_directory: - await dump_cluster_state( - s, - ws, - output_dir=cluster_dump_directory, - func_name=func.__name__, - ) - - task.cancel() - while not task.cancelled(): - await asyncio.sleep(0.01) - - # Hopefully, the hang has been caused by inconsistent - # state, which should be much more meaningful than the - # timeout - validate_state(s, *workers) - - # Remove as much of the traceback as possible; it's - # uninteresting boilerplate from utils_test and asyncio - # and not from the code being tested. 
- raise asyncio.TimeoutError( - f"Test timeout after {timeout}s.\n" - "========== Test stack trace starts here ==========\n" - f"{buffer.getvalue()}" - ) from None - - except pytest.xfail.Exception: - raise - - except Exception: - if cluster_dump_directory and not has_pytestmark( - test_func, "xfail" - ): - await dump_cluster_state( - s, - ws, - output_dir=cluster_dump_directory, - func_name=func.__name__, - ) - raise - - finally: - if client and c.status not in ("closing", "closed"): - await c._close(fast=s.status == Status.closed) - await end_cluster(s, workers) - await asyncio.wait_for(cleanup_global_workers(), 1) - - try: - c = await default_client() - except ValueError: - pass + except Exception as e: + logger.error( + "Failed to start gen_cluster: " + f"{e.__class__.__name__}: {e}; retrying", + exc_info=True, + ) + await asyncio.sleep(1) else: - await c._close(fast=True) - - def get_unclosed(): - return [ - c for c in Comm._instances if not c.closed() - ] + [ - c - for c in _global_clients.values() - if c.status != "closed" - ] + workers[:] = ws + args = [s] + workers + break + if s is False: + raise Exception("Could not start cluster") + if client: + c = await Client( + s.address, + security=security, + asynchronous=True, + **client_kwargs, + ) + args = [c] + args + + try: + coro = func(*args, *outer_args, **kwargs) + task = asyncio.create_task(coro) + coro2 = asyncio.wait_for(asyncio.shield(task), timeout) + result = await coro2 + validate_state(s, *workers) + + except asyncio.TimeoutError: + assert task + buffer = io.StringIO() + # This stack indicates where the coro/test is suspended + task.print_stack(file=buffer) + + if cluster_dump_directory: + await dump_cluster_state( + s, + ws, + output_dir=cluster_dump_directory, + func_name=func.__name__, + ) - try: - start = time() - while time() < start + 60: - gc.collect() - if not get_unclosed(): - break - await asyncio.sleep(0.05) + task.cancel() + while not task.cancelled(): + await asyncio.sleep(0.01) + + # Hopefully, the hang has been caused by inconsistent + # state, which should be much more meaningful than the + # timeout + validate_state(s, *workers) + + # Remove as much of the traceback as possible; it's + # uninteresting boilerplate from utils_test and asyncio + # and not from the code being tested. 
+ raise asyncio.TimeoutError( + f"Test timeout after {timeout}s.\n" + "========== Test stack trace starts here ==========\n" + f"{buffer.getvalue()}" + ) from None + + except pytest.xfail.Exception: + raise + + except Exception: + if cluster_dump_directory and not has_pytestmark( + test_func, "xfail" + ): + await dump_cluster_state( + s, + ws, + output_dir=cluster_dump_directory, + func_name=func.__name__, + ) + raise + + finally: + if client and c.status not in ("closing", "closed"): + await c._close(fast=s.status == Status.closed) + await end_cluster(s, workers) + await asyncio.wait_for(cleanup_global_workers(), 1) + + try: + c = await default_client() + except ValueError: + pass + else: + await c._close(fast=True) + + def get_unclosed(): + return [c for c in Comm._instances if not c.closed()] + [ + c + for c in _global_clients.values() + if c.status != "closed" + ] + + try: + start = time() + while time() < start + 60: + gc.collect() + if not get_unclosed(): + break + await asyncio.sleep(0.05) + else: + if allow_unclosed: + print(f"Unclosed Comms: {get_unclosed()}") else: - if allow_unclosed: - print(f"Unclosed Comms: {get_unclosed()}") - else: - raise RuntimeError( - "Unclosed Comms", get_unclosed() - ) - finally: - Comm._instances.clear() - _global_clients.clear() - - for w in workers: - if getattr(w, "data", None): - try: - w.data.clear() - except OSError: - # zict backends can fail if their storage directory - # was already removed - pass - - return result - - result = loop.run_sync( - coro, timeout=timeout * 2 if timeout else timeout - ) - - return result + raise RuntimeError("Unclosed Comms", get_unclosed()) + finally: + Comm._instances.clear() + _global_clients.clear() + + for w in workers: + if getattr(w, "data", None): + try: + w.data.clear() + except OSError: + # zict backends can fail if their storage directory + # was already removed + pass + + return result + + async def async_fn_outer(): + async with _acheck_active_rpc(active_rpc_timeout=active_rpc_timeout): + if timeout: + return await asyncio.wait_for(async_fn(), timeout=timeout * 2) + return await async_fn() + + return _run_and_close_tornado(async_fn_outer) # Patch the signature so pytest can inject fixtures orig_sig = inspect.signature(func) @@ -1253,7 +1300,9 @@ def _terminate_process(proc): @contextmanager -def popen(args: list[str], flush_output: bool = True, **kwargs): +def popen( + args: list[str], flush_output: bool = True, **kwargs +) -> Iterator[subprocess.Popen[bytes]]: """Start a shell command in a subprocess. Yields a subprocess.Popen object. @@ -1795,7 +1844,7 @@ def check_instances(): for w in Worker._instances: with suppress(RuntimeError): # closed IOLoop - w.loop.add_callback(w.close, report=False, executor_wait=False) + w.loop.add_callback(w.close, executor_wait=False) if w.status in WORKER_ANY_RUNNING: w.loop.add_callback(w.close) Worker._instances.clear() @@ -1832,26 +1881,31 @@ def check_instances(): @contextmanager -def clean(threads=True, instances=True, timeout=1, processes=True): - with check_thread_leak() if threads else nullcontext(): - with pristine_loop() as loop: - with check_process_leak(check=processes): - with check_instances() if instances else nullcontext(): - with check_active_rpc(loop, timeout): - reset_config() +def _reconfigure(): + reset_config() - with dask.config.set( - { - "distributed.comm.timeouts.connect": "5s", - "distributed.admin.tick.interval": "500 ms", - } - ): - # Restore default logging levels - # XXX use pytest hooks/fixtures instead? 
- for name, level in logging_levels.items(): - logging.getLogger(name).setLevel(level) + with dask.config.set( + { + "distributed.comm.timeouts.connect": "5s", + "distributed.admin.tick.interval": "500 ms", + } + ): + # Restore default logging levels + # XXX use pytest hooks/fixtures instead? + for name, level in logging_levels.items(): + logging.getLogger(name).setLevel(level) - yield loop + yield + + +@contextmanager +def clean(threads=True, instances=True, processes=True): + asyncio.set_event_loop(None) + with check_thread_leak() if threads else nullcontext(): + with check_process_leak(check=processes): + with check_instances() if instances else nullcontext(): + with _reconfigure(): + yield @pytest.fixture @@ -2149,3 +2203,61 @@ def raises_with_cause( assert re.search( match_cause, str(exc.__cause__) ), f"Pattern ``{match_cause}`` not found in ``{exc.__cause__}``" + + +def ucx_exception_handler(loop, context): + """UCX exception handler for `ucx_loop` during test. + + Prints the exception and its message. + + Parameters + ---------- + loop: object + Reference to the running event loop + context: dict + Dictionary containing exception details. + """ + msg = context.get("exception", context["message"]) + print(msg) + + +# Let's make sure that UCX gets time to cancel +# progress tasks before closing the event loop. +@pytest.fixture(scope="function") +def ucx_loop(): + """Allows UCX to cancel progress tasks before closing event loop. + + When UCX tasks are not completed in time (e.g., by unexpected Endpoint + closure), clean up tasks before closing the event loop to prevent unwanted + errors from being raised. + """ + ucp = pytest.importorskip("ucp") + + loop = asyncio.new_event_loop() + loop.set_exception_handler(ucx_exception_handler) + ucp.reset() + yield loop + ucp.reset() + loop.close() + + +def wait_for_log_line( + match: bytes, stream: IO[bytes] | None, max_lines: int | None = 10 +) -> bytes: + """ + Read lines from an IO stream until the match is found, and return the matching line. + + Prints each line to test stdout for easier debugging of failures. + """ + assert stream + i = 0 + while True: + if max_lines is not None and i == max_lines: + raise AssertionError( + f"{match!r} not found in {max_lines} log lines. See test stdout for details." + ) + line = stream.readline() + print(line) + if match in line: + return line + i += 1 diff --git a/distributed/worker.py b/distributed/worker.py index 1d90dd70123..e6585d286e2 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -189,6 +189,7 @@ async def wrapper(self, *args, **kwargs): ) logger.exception(e) await _force_close(self) + raise else: @@ -207,6 +208,7 @@ def wrapper(self, *args, **kwargs): ) logger.exception(e) self.loop.add_callback(_force_close, self) + raise return wrapper @@ -1219,7 +1221,7 @@ async def heartbeat(self): logger.error( f"Scheduler was unaware of this worker {self.address!r}. Shutting down." ) - await self.close(report=False) + await self.close() return self.scheduler_delay = response["time"] - middle @@ -1230,12 +1232,12 @@ async def heartbeat(self): self.bandwidth_types.clear() except CommClosedError: logger.warning("Heartbeat to scheduler failed", exc_info=True) - await self.close(report=False) + await self.close() except OSError as e: # Scheduler is gone. 
Respect distributed.comm.timeouts.connect if "Timed out trying to connect" in str(e): logger.info("Timed out while trying to connect during heartbeat") - await self.close(report=False) + await self.close() else: logger.exception(e) raise e @@ -1251,7 +1253,7 @@ async def handle_scheduler(self, comm): self.address, self.status, ) - await self.close(report=False) + await self.close() async def upload_file(self, comm, filename=None, data=None, load=True): out_filename = os.path.join(self.local_directory, filename) @@ -1441,12 +1443,30 @@ async def start_unsafe(self): @log_errors async def close( - self, report=True, timeout=30, nanny=True, executor_wait=True, safe=False + self, + timeout=30, + executor_wait=True, + nanny=True, ): - if self.status in (Status.closed, Status.closing): + # FIXME: The worker should not be allowed to close the nanny. Ownership + # is the other way round. If an external caller wants to close + # nanny+worker, the nanny must be notified first. ==> Remove kwarg + # nanny, see also Scheduler.retire_workers + if self.status in (Status.closed, Status.closing, Status.failed): await self.finished() return + if self.status == Status.init: + # If the worker is still in startup/init and is started by a nanny, + # this means the nanny itself is not up yet. If the Nanny isn't up + # yet, its server will not accept any incoming RPC requests and + # will block until the startup is finished. + # Therefore, this worker cannot communicate with the Nanny during + # startup, so we cannot close it from here. + # In this case, the Nanny will automatically close after inspecting + # the worker status. + nanny = False + disable_gc_diagnosis() try: @@ -1455,8 +1475,6 @@ async def close( logger.info("Stopping worker") if self.status not in WORKER_ANY_RUNNING: logger.info("Closed worker has not yet started: %s", self.status) - if not report: - logger.info("Not reporting worker closure to scheduler") if not executor_wait: logger.info("Not waiting on executor to close") self.status = Status.closing @@ -1473,7 +1491,10 @@ async def close( ) for preload in self.preloads: - await preload.teardown() + try: + await preload.teardown() + except Exception as e: + logger.exception(e) for extension in self.extensions.values(): if hasattr(extension, "close"): @@ -1520,16 +1541,6 @@ async def close( # otherwise c.close() - with suppress(EnvironmentError, TimeoutError): - if report and self.contact_address is not None: - await asyncio.wait_for( - self.scheduler.unregister( - address=self.contact_address, - safe=safe, - stimulus_id=f"worker-close-{time()}", - ), - timeout, - ) await self.scheduler.close_rpc() self._workdir.release() @@ -1607,7 +1618,7 @@ async def close_gracefully(self, restart=None): remove=False, stimulus_id=f"worker-close-gracefully-{time()}", ) - await self.close(safe=True, nanny=not restart) + await self.close(nanny=not restart) async def wait_until_closed(self): warnings.warn("wait_until_closed has moved to finished()") @@ -3612,36 +3623,22 @@ def actor_attribute(self, actor=None, attribute=None) -> dict[str, Any]: return {"status": "error", "exception": to_serialize(ex)} async def _maybe_deserialize_task( - self, ts: TaskState, *, stimulus_id: str - ) -> tuple[Callable, tuple, dict[str, Any]] | None: - if ts.run_spec is None: - return None - try: - start = time() - # Offload deserializing large tasks - if sizeof(ts.run_spec) > OFFLOAD_THRESHOLD: - function, args, kwargs = await offload(_deserialize, *ts.run_spec) - else: - function, args, kwargs =
_deserialize(*ts.run_spec) - stop = time() + self, ts: TaskState + ) -> tuple[Callable, tuple, dict[str, Any]]: + assert ts.run_spec is not None + start = time() + # Offload deserializing large tasks + if sizeof(ts.run_spec) > OFFLOAD_THRESHOLD: + function, args, kwargs = await offload(_deserialize, *ts.run_spec) + else: + function, args, kwargs = _deserialize(*ts.run_spec) + stop = time() - if stop - start > 0.010: - ts.startstops.append( - {"action": "deserialize", "start": start, "stop": stop} - ) - return function, args, kwargs - except Exception as e: - logger.error("Could not deserialize task", exc_info=True) - self.log.append((ts.key, "deserialize-error", stimulus_id, time())) - emsg = error_message(e) - del emsg["status"] # type: ignore - self.transition( - ts, - "error", - **emsg, - stimulus_id=stimulus_id, + if stop - start > 0.010: + ts.startstops.append( + {"action": "deserialize", "start": start, "stop": stop} ) - raise + return function, args, kwargs def _ensure_computing(self) -> RecsInstrs: if self.status != Status.running: @@ -3714,16 +3711,22 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No ) return AlreadyCancelledEvent(key=ts.key, stimulus_id=stimulus_id) + try: + function, args, kwargs = await self._maybe_deserialize_task(ts) + except Exception as exc: + logger.error("Could not deserialize task %s", key, exc_info=True) + return ExecuteFailureEvent.from_exception( + exc, + key=key, + stimulus_id=f"run-spec-deserialize-failed-{time()}", + ) + try: if self.validate: assert not ts.waiting_for_data assert ts.state == "executing" assert ts.run_spec is not None - function, args, kwargs = await self._maybe_deserialize_task( # type: ignore - ts, stimulus_id=stimulus_id - ) - args2, kwargs2 = self._prepare_args_for_execution(ts, args, kwargs) try: @@ -3804,29 +3807,20 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No convert_kwargs_to_str(kwargs2, max_len=1000), result["exception_text"], ) - return ExecuteFailureEvent( + return ExecuteFailureEvent.from_exception( + result, key=key, start=result["start"], stop=result["stop"], - exception=result["exception"], - traceback=result["traceback"], - exception_text=result["exception_text"], - traceback_text=result["traceback_text"], stimulus_id=f"task-erred-{time()}", ) except Exception as exc: logger.error("Exception during execution of task %s.", key, exc_info=True) - msg = error_message(exc) - return ExecuteFailureEvent( + return ExecuteFailureEvent.from_exception( + exc, key=key, - start=None, - stop=None, - exception=msg["exception"], - traceback=msg["traceback"], - exception_text=msg["exception_text"], - traceback_text=msg["traceback_text"], - stimulus_id=f"task-erred-{time()}", + stimulus_id=f"execute-unknown-error-{time()}", ) @functools.singledispatchmethod diff --git a/distributed/worker_memory.py b/distributed/worker_memory.py index 5397f7bac66..5132afb2a3e 100644 --- a/distributed/worker_memory.py +++ b/distributed/worker_memory.py @@ -38,6 +38,7 @@ from dask.utils import format_bytes, parse_bytes, parse_timedelta from distributed import system +from distributed.compatibility import WINDOWS from distributed.core import Status from distributed.metrics import monotonic from distributed.spill import ManualEvictProto, SpillBuffer @@ -333,25 +334,24 @@ def __init__( def memory_monitor(self, nanny: Nanny) -> None: """Track worker's memory. 
Restart if it goes above terminate fraction.""" - if nanny.status != Status.running: - return # pragma: nocover - if nanny.process is None or nanny.process.process is None: + if ( + nanny.status != Status.running + or nanny.process is None + or nanny.process.process is None + or nanny.process.process.pid is None + ): return # pragma: nocover + process = nanny.process.process try: - proc = nanny._psutil_process - memory = proc.memory_info().rss + memory = psutil.Process(process.pid).memory_info().rss except (ProcessLookupError, psutil.NoSuchProcess, psutil.AccessDenied): return # pragma: nocover - if process.pid in (self._last_terminated_pid, None): - # We already sent SIGTERM to the worker, but its handler is still running - # since the previous iteration of the memory_monitor - for example, it - # may be taking a long time deleting all the spilled data from disk. + if memory / self.memory_limit <= self.memory_terminate_fraction: return - self._last_terminated_pid = -1 - if memory / self.memory_limit > self.memory_terminate_fraction: + if self._last_terminated_pid != process.pid: logger.warning( f"Worker {nanny.worker_address} (pid={process.pid}) exceeded " f"{self.memory_terminate_fraction * 100:.0f}% memory budget. " @@ -359,6 +359,29 @@ def memory_monitor(self, nanny: Nanny) -> None: ) self._last_terminated_pid = process.pid process.terminate() + else: + # We already sent SIGTERM to the worker, but the process is still alive + # since the previous iteration of the memory_monitor - for example, some + # user code may have tampered with signal handlers. + # Send SIGKILL for immediate termination. + # + # Note that this should not be a disk-related issue. Unlike in a regular + # worker shutdown, where the worker cleans up its own spill directory, in + # case of SIGTERM no atexit or weakref.finalize callback is triggered + # whatsoever; instead, the nanny cleans up the spill directory *after* the + # worker has been shut down and before starting a new one. + # This is important, as spill directory cleanup may potentially take tens of + # seconds and, if the worker did it, any task that was running and leaking + # would continue to do so for the whole duration of the cleanup, increasing + # the risk of going beyond 100%. 
+ logger.warning( + f"Worker {nanny.worker_address} (pid={process.pid}) is slow to %s", + # On Windows, kill() is an alias to terminate() + "terminate; trying again" + if WINDOWS + else "accept SIGTERM; sending SIGKILL", + ) + process.kill() def parse_memory_limit( diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index e1a4cd3e70d..7ec24c1a608 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -12,6 +12,7 @@ import dask from dask.utils import parse_bytes +from distributed.core import ErrorMessage, error_message from distributed.protocol.serialize import Serialize from distributed.utils import recursive_to_dict @@ -474,7 +475,6 @@ class ExecuteSuccessEvent(StateMachineEvent): stop: float nbytes: int type: type | None - stimulus_id: str __slots__ = tuple(__annotations__) # type: ignore def to_loggable(self, *, handled: float) -> StateMachineEvent: @@ -497,13 +497,38 @@ class ExecuteFailureEvent(StateMachineEvent): traceback: Serialize | None exception_text: str traceback_text: str - stimulus_id: str __slots__ = tuple(__annotations__) # type: ignore def _after_from_dict(self) -> None: self.exception = Serialize(Exception()) self.traceback = None + @classmethod + def from_exception( + cls, + err_or_msg: BaseException | ErrorMessage, + *, + key: str, + start: float | None = None, + stop: float | None = None, + stimulus_id: str, + ) -> ExecuteFailureEvent: + if isinstance(err_or_msg, dict): + msg = err_or_msg + else: + msg = error_message(err_or_msg) + + return cls( + key=key, + start=start, + stop=stop, + exception=msg["exception"], + traceback=msg["traceback"], + exception_text=msg["exception_text"], + traceback_text=msg["traceback_text"], + stimulus_id=stimulus_id, + ) + @dataclass class CancelComputeEvent(StateMachineEvent): diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 2a688b9097f..c318aa1a405 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,125 @@ Changelog ========= +.. _v2022.05.2: + +2022.05.2 +--------- + +Released on May 26, 2022 + +Enhancements +^^^^^^^^^^^^ +- Add a lock to ``distributed.profile`` for better concurrency control (:pr:`6421`) `Hendrik Makait`_ +- Send ``SIGKILL`` after ``SIGTERM`` when passing 95% memory (:pr:`6419`) `crusaderky`_ + +Bug Fixes +^^^^^^^^^ +- Log rather than raise exceptions in ``preload.teardown()`` (:pr:`6458`) `Matthew Rocklin`_ +- Handle failing ``plugin.close()`` calls during scheduler shutdown (:pr:`6450`) `Matthew Rocklin`_ +- Fix slicing bug in ``ensure_memoryview`` (:pr:`6449`) `jakirkham`_ +- Generalize UCX errors on ``connect()`` and correct pytest fixtures (:pr:`6434`) `Peter Andreas Entschev`_ +- Run cluster widget periodic callbacks on the correct event loop (:pr:`6444`) `Thomas Grainger`_ + +Maintenance +^^^^^^^^^^^ +- Disable ``pytest-asyncio`` if installed (:pr:`6436`) `Jacob Tomlinson`_ +- Close client in sync test_actor tests (:pr:`6459`) `Thomas Grainger`_ +- Ignore ``ServerSession.with_document_locked unawaited`` (:pr:`6447`) `Thomas Grainger`_ +- Remove ``coverage`` pin from Python 3.10 environment (:pr:`6439`) `Thomas Grainger`_ +- Annotate ``remove_worker`` (:pr:`6441`) `crusaderky`_ +- Update gpuCI ``RAPIDS_VER`` to ``22.08`` (:pr:`6428`) + + +.. 
_v2022.05.1: + +2022.05.1 +--------- + +Released on May 24, 2022 + +New Features +^^^^^^^^^^^^ +- Add HTTP API to scheduler (:pr:`6270`) `Matthew Murray`_ +- Shuffle Service with Scheduler Logic (:pr:`6007`) `Matthew Rocklin`_ + +Enhancements +^^^^^^^^^^^^ +- Follow-up on removing ``report`` and ``safe`` from ``Worker.close`` (:pr:`6423`) `Gabe Joseph`_ +- Server close faster (:pr:`6415`) `Florian Jetter`_ +- Disable HTTP API by default (:pr:`6420`) `Jacob Tomlinson`_ +- Remove ``report`` and ``safe`` from ``Worker.close`` (:pr:`6363`) `Florian Jetter`_ +- Allow deserialized plugins in ``register_scheduler_plugin`` (:pr:`6401`) `Matthew Rocklin`_ +- ``WorkerState`` are different for different addresses (:pr:`6398`) `Florian Jetter`_ +- Do not filter tasks before gathering data (:pr:`6371`) `crusaderky`_ +- Remove worker reconnect (:pr:`6361`) `Gabe Joseph`_ +- Add ``SchedulerPlugin.log_event handler`` (:pr:`6381`) `Matthew Rocklin`_ +- Ensure occupancy tracking works as expected for long running tasks (:pr:`6351`) `Florian Jetter`_ +- ``stimulus_id`` for all ``Instructions`` (:pr:`6347`) `crusaderky`_ +- Refactor missing-data command (:pr:`6332`) `crusaderky`_ +- Add ``idempotent`` to ``register_scheduler_plugin`` client (:pr:`6328`) `Alex Ford`_ +- Add option to specify a scheduler address for workers to use (:pr:`5944`) `Enric Tejedor`_ + +Bug Fixes +^^^^^^^^^ +- Remove stray ``breakpoint`` (:pr:`6417`) `Thomas Grainger`_ +- Fix API JSON MIME type (:pr:`6397`) `Jacob Tomlinson`_ +- Remove wrong ``assert`` in handle compute (:pr:`6370`) `Florian Jetter`_ +- Ensure multiple clients can cancel their key without interference (:pr:`6016`) `Florian Jetter`_ +- Fix ``Nanny`` shutdown assertion (:pr:`6357`) `Gabe Joseph`_ +- Fix ``fail_hard`` for sync functions (:pr:`6269`) `Gabe Joseph`_ +- Prevent infinite transition loops; more aggressive ``validate_state()`` (:pr:`6318`) `crusaderky`_ +- Ensure cleanup of many GBs of spilled data on terminate (:pr:`6280`) `crusaderky`_ +- Fix ``WORKER_ANY_RUNNING`` regression (:pr:`6297`) `Florian Jetter`_ +- Race conditions from fetch to compute while AMM requests replica (:pr:`6248`) `Florian Jetter`_ +- Ensure resumed tasks are not accidentally forgotten (:pr:`6217`) `Florian Jetter`_ +- Do not allow closing workers to be awaited again (:pr:`5910`) `Florian Jetter`_ + +Deprecations +^^^^^^^^^^^^ +- Move ``wait_for_signals`` to private module and deprecate ``distributed.cli.utils`` (:pr:`6367`) `Hendrik Makait`_ + +Documentation +^^^^^^^^^^^^^ +- Fix typos and whitespace in ``worker.py`` (:pr:`6326`) `Hendrik Makait`_ +- Fix link to memory trimming documentation (:pr:`6317`) `Marco Wolsza`_ + +Maintenance +^^^^^^^^^^^ +- Make ``gen_test`` show up in VSCode test discovery (:pr:`6424`) `Gabe Joseph`_ +- WSMR / ``deserialize_task`` (:pr:`6411`) `crusaderky`_ +- Restore signal handlers after wait for signals is done (:pr:`6400`) `Thomas Grainger`_ +- ``fail_hard`` should reraise (:pr:`6399`) `crusaderky`_ +- Revisit tests mocking ``gather_dep`` (:pr:`6385`) `crusaderky`_ +- Fix flaky ``test_in_flight_lost_after_resumed`` (:pr:`6372`) `Florian Jetter`_ +- Restore install_signal_handlers due to downstream dependencies (:pr:`6366`) `Hendrik Makait`_ +- Improve ``catch_unhandled_exceptions`` (:pr:`6358`) `Gabe Joseph`_ +- Remove all invocations of ``IOLoop.run_sync`` from CLI (:pr:`6205`) `Hendrik Makait`_ +- Remove ``transition-counter-max`` from config (:pr:`6349`) `crusaderky`_ +- Use ``list`` comprehension in ``pickle_loads`` (:pr:`6343`) `jakirkham`_ +- 
Improve ``ensure_memoryview`` test coverage & make minor fixes (:pr:`6333`) `jakirkham`_ +- Remove leaking reference to ``workers`` from ``gen_cluster`` (:pr:`6337`) `Hendrik Makait`_ +- Partial annotations for ``stealing.py`` (:pr:`6338`) `crusaderky`_ +- Validate and debug state machine on ``handle_compute_task`` (:pr:`6327`) `crusaderky`_ +- Bump pyupgrade and clean up ``# type: ignore`` (:pr:`6293`) `crusaderky`_ +- ``gen_cluster`` to write to ``/tmp`` (:pr:`6335`) `crusaderky`_ +- Transition table as a ``ClassVar`` (:pr:`6331`) `crusaderky`_ +- Simplify ``ensure_memoryview`` test with ``array`` (:pr:`6322`) `jakirkham`_ +- Refactor ``ensure_communicating`` (:pr:`6165`) `crusaderky`_ +- Review scheduler annotations, part 2 (:pr:`6253`) `crusaderky`_ +- Use ``w`` for ``writeable`` branch in ``pickle_loads`` (:pr:`6314`) `jakirkham`_ +- Simplify frame handling in ``ws`` (:pr:`6294`) `jakirkham`_ +- Use ``ensure_bytes`` from ``dask.utils`` (:pr:`6295`) `jakirkham`_ +- Use ``ensure_memoryview`` in ``array`` deserialization (:pr:`6300`) `jakirkham`_ +- Escape < > when generating Junit report (:pr:`6306`) `crusaderky`_ +- Use ``codecs.decode`` to deserialize errors (:pr:`6274`) `jakirkham`_ +- Minimize copying in ``maybe_compress`` & ``byte_sample`` (:pr:`6273`) `jakirkham`_ +- Skip ``test_release_evloop_while_spilling`` on OSX (:pr:`6291`) `Florian Jetter`_ +- Simplify logic in ``get_default_compression`` (:pr:`6260`) `jakirkham`_ +- Cleanup old compression workarounds (:pr:`6259`) `jakirkham`_ +- Re-enable NVML monitoring for WSL (:pr:`6119`) `Charles Blackmon-Luca`_ + + .. _v2022.05.0: 2022.05.0 @@ -2847,7 +2966,7 @@ This is a small bugfix release due to a config change upstream. - Fixed an uncaught exception in ``distributed.joblib`` with a ``LocalCluster`` using only threads (:issue:`1775`) `Tom Augspurger`_ - Format bytes in info worker page (:pr:`1752`) `Matthew Rocklin`_ -- Add pass-through arguments for scheduler/worker `--preload` modules. (:pr:`1634`) `Alexander Ford`_ +- Add pass-through arguments for scheduler/worker `--preload` modules. (:pr:`1634`) `Alex Ford`_ - Use new LZ4 API (:pr:`1757`) `Thrasibule`_ - Replace dask.optimize with dask.optimization (:pr:`1754`) `Matthew Rocklin`_ - Add graph layout engine and bokeh plot (:pr:`1756`) `Matthew Rocklin`_ @@ -2879,7 +2998,7 @@ This is a small bugfix release due to a config change upstream. - Ensure dumps_function works with unhashable functions (:pr:`1662`) `Matthew Rocklin`_ - Collect client name ids rom client-name config variable (:pr:`1664`) `Matthew Rocklin`_ - Allow simultaneous use of --name and --nprocs in dask-worker (:pr:`1665`) `Matthew Rocklin`_ -- Add support for grouped adaptive scaling and adaptive behavior overrides (:pr:`1632`) `Alexander Ford`_ +- Add support for grouped adaptive scaling and adaptive behavior overrides (:pr:`1632`) `Alex Ford`_ - Share scheduler RPC between worker and client (:pr:`1673`) `Matthew Rocklin`_ - Allow ``retries=`` in ClientExecutor (:pr:`1672`) `@rqx`_ - Improve documentation for get_client and dask.compute examples (:pr:`1638`) `Scott Sievert`_ @@ -3409,7 +3528,7 @@ significantly without many new features. .. _`Daniel Li`: https://github.com/li-dan .. _`Brett Naul`: https://github.com/bnaul .. _`Cornelius Riemenschneider`: https://github.com/corni -.. _`Alexander Ford`: https://github.com/asford +.. _`Alex Ford`: https://github.com/asford .. _`@rqx`: https://github.com/rqx .. _`Min RK`: https://github.comminrk/ .. 
_`Bruce Merry`: https://github.com/bmerry @@ -3607,3 +3726,7 @@ significantly without many new features. .. _`Duncan McGregor`: https://github.com/dmcg .. _`Eric Engestrom`: https://github.com/lace .. _`ungarj`: https://github.com/ungarj +.. _`Matthew Murray`: https://github.com/Matt711 +.. _`Enric Tejedor`: https://github.com/etejedor +.. _`Hendrik Makait`: https://github.com/hendrikmakait +.. _`Marco Wolsza`: https://github.com/maawoo diff --git a/requirements.txt b/requirements.txt index 8b41f991781..abe97b5caef 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ click >= 6.6 cloudpickle >= 1.5.0 -dask == 2022.05.0 +dask == 2022.05.2 jinja2 locket >= 1.0.0 msgpack >= 0.6.0 diff --git a/setup.cfg b/setup.cfg index d55d03bea5b..1d58d8b0ccf 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,7 +39,16 @@ tag_prefix = parentdir_prefix = distributed- [tool:pytest] -addopts = -v -rsxfE --durations=20 --color=yes --ignore=continuous_integration --ignore=docs --ignore=.github --strict-markers --strict-config +addopts = + -v -rsxfE + --durations=20 + --color=yes + --ignore=continuous_integration + --ignore=docs + --ignore=.github + --strict-markers + --strict-config + -p no:asyncio filterwarnings = error ignore:Please use `dok_matrix` from the `scipy\.sparse` namespace, the `scipy\.sparse\.dok` namespace is deprecated.:DeprecationWarning @@ -60,6 +69,7 @@ filterwarnings = ignore:coroutine 'PooledRPCCall\.__getattr__\.\.send_recv_from_rpc' was never awaited:RuntimeWarning ignore:coroutine 'Scheduler\.restart' was never awaited:RuntimeWarning ignore:coroutine 'Semaphore._refresh_leases' was never awaited:RuntimeWarning + ignore:coroutine 'ServerSession\.with_document_locked' was never awaited ignore:overflow encountered in long_scalars:RuntimeWarning ignore:Creating scratch directories is taking a surprisingly long time.*:UserWarning ignore:Running on a single-machine scheduler when a distributed client is active might lead to unexpected results\.:UserWarning @@ -75,7 +85,6 @@ filterwarnings = ignore:(?s)Exception ignored in. ._needs_document_lock_wrapper' was never awaited minversion = 6 markers = ci1: marks tests as belonging to 1 out of 2 partitions to run on CI ('-m "not ci1"' for second partition)
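[Illustrative aside, not part of the patch] The memory-terminate change in ``distributed/worker_memory.py`` above boils down to: send SIGTERM the first time a worker process exceeds the terminate fraction, and escalate to SIGKILL if the same process is still alive at the next monitor interval. The sketch below is a simplified rendering of that escalation under assumed names; ``check_worker_memory`` and the ``state`` dict are hypothetical stand-ins for the nanny plugin's internal attributes and its psutil-based checks.

def check_worker_memory(process, memory, memory_limit, terminate_fraction, state):
    """Terminate the worker once it exceeds its budget; kill it if it is still
    alive at the next check (e.g. because user code tampered with SIGTERM)."""
    if memory / memory_limit <= terminate_fraction:
        return
    if state.get("last_terminated_pid") != process.pid:
        # First breach for this process: ask it to shut down with SIGTERM.
        state["last_terminated_pid"] = process.pid
        process.terminate()
    else:
        # The earlier SIGTERM did not take effect before the next monitor
        # interval: escalate to SIGKILL (on Windows, kill() aliases terminate()).
        process.kill()

With the 10ms monitor interval configured in the tests above, a worker that ignores SIGTERM is therefore killed roughly one interval later.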