From 7680b85aa598582f347dc63b5d676302d7475ea3 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 4 Mar 2024 15:43:30 +0000 Subject: [PATCH] Refactor restart() and restart_workers() --- distributed/client.py | 80 ++++--- .../diagnostics/tests/test_progress.py | 4 +- distributed/nanny.py | 1 + distributed/scheduler.py | 226 +++++++++++++----- distributed/tests/test_client.py | 63 +++-- distributed/tests/test_failed_workers.py | 5 +- distributed/tests/test_nanny.py | 17 +- distributed/tests/test_scheduler.py | 15 +- distributed/utils_test.py | 24 ++ 9 files changed, 281 insertions(+), 154 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 46982203530..27d4d81fee3 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -35,7 +35,7 @@ from dask.core import flatten, validate_key from dask.highlevelgraph import HighLevelGraph from dask.optimization import SubgraphCallable -from dask.typing import no_default +from dask.typing import NoDefault, no_default from dask.utils import ( apply, ensure_dict, @@ -48,7 +48,7 @@ ) from dask.widgets import get_template -from distributed.core import ErrorMessage, OKMessage +from distributed.core import OKMessage from distributed.protocol.serialize import _is_dumpable from distributed.utils import Deadline, wait_for @@ -858,8 +858,9 @@ def __init__( ): if timeout is no_default: timeout = dask.config.get("distributed.comm.timeouts.connect") - if timeout is not None: - timeout = parse_timedelta(timeout, "s") + timeout = parse_timedelta(timeout, "s") + if timeout is None: + raise ValueError("None is an invalid value for global client timeout") self._timeout = timeout self.futures = dict() @@ -1252,8 +1253,7 @@ async def _start(self, timeout=no_default, **kwargs): if timeout is no_default: timeout = self._timeout - if timeout is not None: - timeout = parse_timedelta(timeout, "s") + timeout = parse_timedelta(timeout, "s") address = self._start_arg if self.cluster is not None: @@ -3593,16 +3593,24 @@ def persist( else: return result - async def _restart(self, timeout=no_default, wait_for_workers=True): + async def _restart( + self, timeout: str | int | float | NoDefault, wait_for_workers: bool + ) -> None: if timeout is no_default: timeout = self._timeout * 4 - if timeout is not None: - timeout = parse_timedelta(timeout, "s") + timeout = parse_timedelta(cast("str|int|float", timeout), "s") - await self.scheduler.restart(timeout=timeout, wait_for_workers=wait_for_workers) - return self + await self.scheduler.restart( + timeout=timeout, + wait_for_workers=wait_for_workers, + stimulus_id=f"client-restart-{time()}", + ) - def restart(self, timeout=no_default, wait_for_workers=True): + def restart( + self, + timeout: str | int | float | NoDefault = no_default, + wait_for_workers: bool = True, + ): """ Restart all workers. Reset local state. Optionally wait for workers to return. @@ -3639,46 +3647,43 @@ def restart(self, timeout=no_default, wait_for_workers=True): async def _restart_workers( self, workers: list[str], - timeout: int | float | None = None, - raise_for_error: bool = True, - ) -> dict[str, str | ErrorMessage]: + timeout: str | int | float | NoDefault, + raise_for_error: bool, + ) -> dict[str, Literal["OK", "timed out"]]: + if timeout is no_default: + timeout = self._timeout * 4 + timeout = parse_timedelta(cast("str|int|float", timeout), "s") + info = self.scheduler_info() name_to_addr = {meta["name"]: addr for addr, meta in info["workers"].items()} worker_addrs = [name_to_addr.get(w, w) for w in workers] - restart_out: dict[str, str | ErrorMessage] = await self.scheduler.broadcast( - msg={"op": "restart", "timeout": timeout}, + out: dict[ + str, Literal["OK", "timed out"] + ] = await self.scheduler.restart_workers( workers=worker_addrs, - nanny=True, + timeout=timeout, + on_error="raise" if raise_for_error else "return", + stimulus_id=f"client-restart-workers-{time()}", ) - # Map keys back to original `workers` input names/addresses - results = {w: restart_out[w_addr] for w, w_addr in zip(workers, worker_addrs)} - - timeout_workers = [w for w, status in results.items() if status == "timed out"] - if timeout_workers and raise_for_error: - raise TimeoutError( - f"The following workers failed to restart with {timeout} seconds: {timeout_workers}" - ) - - errored: list[ErrorMessage] = [m for m in results.values() if "exception" in m] # type: ignore - if errored and raise_for_error: - raise pickle.loads(errored[0]["exception"]) # type: ignore - return results + out = {w: out[w_addr] for w, w_addr in zip(workers, worker_addrs)} + if raise_for_error: + assert all(v == "OK" for v in out.values()), out + return out def restart_workers( self, workers: list[str], - timeout: int | float | None = None, + timeout: str | int | float | NoDefault = no_default, raise_for_error: bool = True, - ) -> dict[str, str]: + ): """Restart a specified set of workers .. note:: Only workers being monitored by a :class:`distributed.Nanny` can be restarted. - - See ``Nanny.restart`` for more details. + See ``Nanny.restart`` for more details. Parameters ---------- @@ -3693,7 +3698,7 @@ def restart_workers( Returns ------- - dict[str, str] + dict[str, "OK" | "timed out"] Mapping of worker and restart status, the keys will match the original values passed in via ``workers``. @@ -3727,7 +3732,8 @@ def restart_workers( for worker, meta in info["workers"].items(): if (worker in workers or meta["name"] in workers) and meta["nanny"] is None: raise ValueError( - f"Restarting workers requires a nanny to be used. Worker {worker} has type {info['workers'][worker]['type']}." + f"Restarting workers requires a nanny to be used. Worker " + f"{worker} has type {info['workers'][worker]['type']}." ) return self.sync( self._restart_workers, diff --git a/distributed/diagnostics/tests/test_progress.py b/distributed/diagnostics/tests/test_progress.py index 7bad80da7e2..2d4310b992f 100644 --- a/distributed/diagnostics/tests/test_progress.py +++ b/distributed/diagnostics/tests/test_progress.py @@ -210,7 +210,7 @@ def f(x): await wait([future]) assert p.state["memory"] == {"f": {future.key}} - await c._restart() + await c.restart() for coll in [p.all] + list(p.state.values()): assert not coll @@ -262,7 +262,7 @@ async def test_group_timing(c, s, a, b): ] ) - await s.restart() + await s.restart(stimulus_id="test") assert len(p.time) == 2 assert len(p.nthreads) == 2 assert len(p.compute) == 0 diff --git a/distributed/nanny.py b/distributed/nanny.py index 2c14224d363..99644e9292d 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -851,6 +851,7 @@ async def kill( assert self.status in ( Status.running, Status.failed, # process failed to start, but hasn't been joined yet + Status.closing_gracefully, ), self.status self.status = Status.stopping logger.info("Nanny asking worker to close. Reason: %s", reason) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 320edc391a4..c9fa0413f70 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3813,6 +3813,7 @@ async def post(self): "replicate": self.replicate, "run_function": self.run_function, "restart": self.restart, + "restart_workers": self.restart_workers, "update_data": self.update_data, "set_resources": self.add_resources, "retire_workers": self.retire_workers, @@ -6176,39 +6177,30 @@ async def gather( return {"status": "error", "keys": list(failed_keys)} @log_errors - async def restart(self, client=None, timeout=30, wait_for_workers=True): - """ - Restart all workers. Reset local state. Optionally wait for workers to return. - - Workers without nannies are shut down, hoping an external deployment system - will restart them. Therefore, if not using nannies and your deployment system - does not automatically restart workers, ``restart`` will just shut down all - workers, then time out! - - After ``restart``, all connected workers are new, regardless of whether ``TimeoutError`` - was raised. Any workers that failed to shut down in time are removed, and - may or may not shut down on their own in the future. + async def restart( + self, + *, + client: str | None = None, + timeout: float = 30, + wait_for_workers: bool = True, + stimulus_id: str, + ) -> None: + """Forget all tasks and call restart_workers on all workers. Parameters ---------- timeout: - How long to wait for workers to shut down and come back, if ``wait_for_workers`` - is True, otherwise just how long to wait for workers to shut down. - Raises ``asyncio.TimeoutError`` if this is exceeded. + See restart_workers wait_for_workers: - Whether to wait for all workers to reconnect, or just for them to shut down - (default True). Use ``restart(wait_for_workers=False)`` combined with - :meth:`Client.wait_for_workers` for granular control over how many workers to - wait for. + See restart_workers See also -------- Client.restart Client.restart_workers + Scheduler.restart_workers """ - stimulus_id = f"restart-{time()}" - - logger.info("Restarting workers and releasing all keys.") + logger.info(f"Restarting workers and releasing all keys ({stimulus_id=})") for cs in self.clients.values(): self.client_releases_keys( keys=[ts.key for ts in cs.wants_what], @@ -6226,19 +6218,92 @@ async def restart(self, client=None, timeout=30, wait_for_workers=True): except Exception as e: logger.exception(e) + await self.restart_workers( + client=client, + timeout=timeout, + wait_for_workers=wait_for_workers, + stimulus_id=stimulus_id, + ) + + @log_errors + async def restart_workers( + self, + workers: list[str] | None = None, + *, + client: str | None = None, + timeout: float = 30, + wait_for_workers: bool = True, + on_error: Literal["raise", "return"] = "raise", + stimulus_id: str, + ) -> dict[str, Literal["OK", "removed", "timed out"]]: + """Restart selected workers. Optionally wait for workers to return. + + Workers without nannies are shut down, hoping an external deployment system + will restart them. Therefore, if not using nannies and your deployment system + does not automatically restart workers, ``restart`` will just shut down all + workers, then time out! + + After ``restart``, all connected workers are new, regardless of whether + ``TimeoutError`` was raised. Any workers that failed to shut down in time are + removed, and may or may not shut down on their own in the future. + + Parameters + ---------- + workers: + List of worker addresses to restart. If omitted, restart all workers. + timeout: + How long to wait for workers to shut down and come back, if ``wait_for_workers`` + is True, otherwise just how long to wait for workers to shut down. + Raises ``asyncio.TimeoutError`` if this is exceeded. + wait_for_workers: + Whether to wait for all workers to reconnect, or just for them to shut down + (default True). Use ``restart(wait_for_workers=False)`` combined with + :meth:`Client.wait_for_workers` for granular control over how many workers to + wait for. + on_error: + If 'raise' (the default), raise if any nanny times out while restarting the + worker. If 'return', return error messages. + + Returns + ------- + {worker address: "OK", "no nanny", or "timed out" or error message} + + See also + -------- + Client.restart + Client.restart_workers + Scheduler.restart + """ n_workers = len(self.workers) + if workers is None: + workers = list(self.workers) + logger.info(f"Restarting all workers ({stimulus_id=}") + else: + workers = list(set(workers).intersection(self.workers)) + logger.info(f"Restarting {len(workers)} workers: {workers} ({stimulus_id=}") + nanny_workers = { - addr: ws.nanny for addr, ws in self.workers.items() if ws.nanny + addr: self.workers[addr].nanny + for addr in workers + if self.workers[addr].nanny } - # Close non-Nanny workers. We have no way to restart them, so we just let them go, - # and assume a deployment system is going to restart them for us. - await asyncio.gather( - *( - self.remove_worker(address=addr, stimulus_id=stimulus_id) - for addr in self.workers - if addr not in nanny_workers + # Close non-Nanny workers. We have no way to restart them, so we just let them + # go, and assume a deployment system is going to restart them for us. + no_nanny_workers = [addr for addr in workers if addr not in nanny_workers] + if no_nanny_workers: + logger.warning( + f"Workers {no_nanny_workers} do not use a nanny and will be terminated " + "without restarting them" ) - ) + await asyncio.gather( + *( + self.remove_worker(address=addr, stimulus_id=stimulus_id) + for addr in no_nanny_workers + ) + ) + out: dict[str, Literal["OK", "removed", "timed out"]] + out = {addr: "removed" for addr in no_nanny_workers} + start = monotonic() logger.debug("Send kill signal to nannies: %s", nanny_workers) async with contextlib.AsyncExitStack() as stack: @@ -6250,18 +6315,13 @@ async def restart(self, client=None, timeout=30, wait_for_workers=True): for nanny_address in nanny_workers.values() ) ) - - start = monotonic() resps = await asyncio.gather( *( wait_for( # FIXME does not raise if the process fails to shut down, # see https://github.com/dask/distributed/pull/6427/files#r894917424 # NOTE: Nanny will automatically restart worker process when it's killed - nanny.kill( - reason="scheduler-restart", - timeout=timeout, - ), + nanny.kill(reason=stimulus_id, timeout=timeout), timeout, ) for nanny in nannies @@ -6273,46 +6333,80 @@ async def restart(self, client=None, timeout=30, wait_for_workers=True): # Remove any workers that failed to shut down, so we can guarantee # that after `restart`, there are no old workers around. - bad_nannies = [ - addr for addr, resp in zip(nanny_workers, resps) if resp is not None - ] + bad_nannies = set() + for addr, resp in zip(nanny_workers, resps): + if resp is None: + out[addr] = "OK" + elif isinstance(resp, (OSError, TimeoutError)): + bad_nannies.add(addr) + out[addr] = "timed out" + else: # pragma: nocover + raise resp + if bad_nannies: + logger.error( + f"Workers {list(bad_nannies)} did not shut down within {timeout}s; " + "force closing" + ) await asyncio.gather( *( self.remove_worker(addr, stimulus_id=stimulus_id) for addr in bad_nannies ) ) + if on_error == "raise": + raise TimeoutError( + f"{len(bad_nannies)}/{len(nannies)} nanny worker(s) did not " + f"shut down within {timeout}s: {bad_nannies}" + ) - raise TimeoutError( - f"{len(bad_nannies)}/{len(nannies)} nanny worker(s) did not shut down within {timeout}s" - ) + if client: + self.log_event(client, {"action": "restart-workers", "workers": workers}) + self.log_event( + "all", {"action": "restart-workers", "workers": workers, "client": client} + ) + + if not wait_for_workers: + logger.info( + "Workers restart finished (did not wait for new workers) " + f"({stimulus_id=}" + ) + return out + + # NOTE: if new (unrelated) workers join while we're waiting, we may return + # before our shut-down workers have come back up. That's fine; workers are + # interchangeable. + while monotonic() < start + timeout and len(self.workers) < n_workers: + await asyncio.sleep(0.2) + + if len(self.workers) >= n_workers: + logger.info(f"Workers restart finished ({stimulus_id=}") + return out + + msg = ( + f"Waited for {len(workers)} worker(s) to reconnect after restarting but, " + f"after {timeout}s, {n_workers - len(self.workers)} have not returned. " + "Consider a longer timeout, or `wait_for_workers=False`." + ) + if no_nanny_workers: + msg += ( + f" The {len(no_nanny_workers)} worker(s) not using Nannies were just shut " + "down instead of restarted (restart is only possible with Nannies). If " + "your deployment system does not automatically re-launch terminated " + "processes, then those workers will never come back, and `Client.restart` " + "will always time out. Do not use `Client.restart` in that case." + ) - self.log_event([client, "all"], {"action": "restart", "client": client}) + if on_error == "raise": + raise TimeoutError(msg) + logger.error(f"{msg} ({stimulus_id=})") - if wait_for_workers: - while len(self.workers) < n_workers: - # NOTE: if new (unrelated) workers join while we're waiting, we may return before - # our shut-down workers have come back up. That's fine; workers are interchangeable. - if monotonic() < start + timeout: - await asyncio.sleep(0.2) - else: - msg = ( - f"Waited for {n_workers} worker(s) to reconnect after restarting, " - f"but after {timeout}s, only {len(self.workers)} have returned. " - "Consider a longer timeout, or `wait_for_workers=False`." - ) + new_nannies = {ws.nanny for ws in self.workers.values() if ws.nanny} + for worker_addr, nanny_addr in nanny_workers.items(): + if nanny_addr not in new_nannies: + out[worker_addr] = "timed out" - if (n_nanny := len(nanny_workers)) < n_workers: - msg += ( - f" The {n_workers - n_nanny} worker(s) not using Nannies were just shut " - "down instead of restarted (restart is only possible with Nannies). If " - "your deployment system does not automatically re-launch terminated " - "processes, then those workers will never come back, and `Client.restart` " - "will always time out. Do not use `Client.restart` in that case." - ) - raise TimeoutError(msg) from None - logger.info("Restarting finished.") + return out async def broadcast( self, diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 29d7d8ce7e8..2086505cadb 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -88,6 +88,8 @@ NO_AMM, BlockedGatherDep, BlockedGetData, + BlockedKillNanny, + BlockedStartNanny, TaskStateMetadataPlugin, _UnhashableCallable, async_poll_for, @@ -3649,7 +3651,7 @@ async def test_Client_clears_references_after_restart(c, s, a, b): assert x.key in c.futures with pytest.raises(TimeoutError): - await c.restart(timeout=5) + await c.restart(timeout=1) assert x.key not in c.refcount assert not c.futures @@ -5007,8 +5009,11 @@ async def test_restart_workers(c, s, a, b): # Restart a single worker a_worker_addr = a.worker_address results = await c.restart_workers(workers=[a.worker_address]) - assert results[a_worker_addr] == "OK" - assert set(s.workers) == {a.worker_address, b.worker_address} + assert results == {a_worker_addr: "OK"} + # There can be some lag between a worker connecting to the scheduler and the + # nanny updating the worker's port + while set(s.workers) != {a.worker_address, b.worker_address}: + await asyncio.sleep(0.01) # Make sure worker start times are as expected results = await c.run(lambda dask_worker: dask_worker.start_time) @@ -5028,48 +5033,56 @@ async def test_restart_workers_no_nanny_raises(c, s, a, b): assert a.address in msg -class SlowKillNanny(Nanny): - async def kill(self, timeout=2, **kwargs): - await asyncio.sleep(2) - return await super().kill(timeout=timeout) - - @pytest.mark.slow @pytest.mark.parametrize("raise_for_error", (True, False)) -@gen_cluster(client=True, Worker=SlowKillNanny) -async def test_restart_workers_timeout(c, s, a, b, raise_for_error): +@gen_cluster(client=True, nthreads=[("", 1)], Worker=BlockedKillNanny) +async def test_restart_workers_kill_timeout(c, s, a, raise_for_error): kwargs = dict(workers=[a.worker_address], timeout=0.001) if raise_for_error: # default is to raise with pytest.raises(TimeoutError) as excinfo: await c.restart_workers(**kwargs) - msg = str(excinfo.value).lower() - assert "workers failed to restart" in msg + msg = str(excinfo.value) + assert "1/1 nanny worker(s) did not shut down within 0.001s" in msg assert a.worker_address in msg else: results = await c.restart_workers(raise_for_error=raise_for_error, **kwargs) assert results == {a.worker_address: "timed out"} + a.wait_kill.set() @pytest.mark.slow @pytest.mark.parametrize("raise_for_error", (True, False)) -@gen_cluster(client=True, Worker=SlowKillNanny) -async def test_restart_workers_exception(c, s, a, b, raise_for_error): +@gen_cluster(client=True, nthreads=[]) +async def test_restart_workers_restart_timeout(c, s, raise_for_error): + a = BlockedStartNanny(s.address) + a.wait_instantiate.set() + async with a: + a.wait_instantiate.clear() + kwargs = dict(workers=[a.worker_address], timeout=0.001) + + if raise_for_error: # default is to raise + with pytest.raises(TimeoutError) as excinfo: + await c.restart_workers(**kwargs) + msg = str(excinfo.value) + assert "1/1 nanny worker(s) did not shut down within 0.001s" in msg + assert a.worker_address in msg + else: + results = await c.restart_workers(raise_for_error=raise_for_error, **kwargs) + assert results == {a.worker_address: "timed out"} + + +@pytest.mark.slow +@gen_cluster(client=True, Worker=Nanny) +async def test_restart_workers_exception(c, s, a, b): async def fail_instantiate(*_args, **_kwargs): raise ValueError("broken") a.instantiate = fail_instantiate - if raise_for_error: # default is to raise - with pytest.raises(ValueError, match="broken"): - await c.restart_workers(workers=[a.worker_address]) - else: - results = await c.restart_workers( - workers=[a.worker_address], raise_for_error=raise_for_error - ) - msg = results[a.worker_address] - assert msg["status"] == "error" - assert msg["exception_text"] == "ValueError('broken')" + with captured_logger("distributed.nanny") as log, pytest.raises(TimeoutError): + await c.restart_workers(workers=[a.worker_address], timeout=3) + assert "broken" in log.getvalue() @pytest.mark.slow diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 8d7eec43606..7c7c667e4b7 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -156,8 +156,7 @@ async def test_restart(c, s, a, b): assert s.tasks[y.key].state == "memory" assert s.tasks[z.key].state != "memory" - f = await c.restart() - assert f is c + await c.restart() assert len(s.workers) == 2 assert not any(ws.occupancy for ws in s.workers.values()) @@ -259,7 +258,7 @@ async def test_restart_scheduler(s, a, b): assert pids[0] assert pids[1] - await s.restart() + await s.restart(stimulus_id="test") assert len(s.workers) == 2 pids2 = (a.pid, b.pid) diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 74d049a35d4..ff607646d9a 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -29,6 +29,7 @@ from distributed.protocol.pickle import dumps from distributed.utils import TimeoutError, get_mp_context, parse_ports from distributed.utils_test import ( + BlockedStartNanny, async_poll_for, captured_logger, gen_cluster, @@ -217,7 +218,7 @@ async def test_nanny_timeout(c, s, a): with captured_logger( logging.getLogger("distributed.nanny"), level=logging.ERROR ) as logger: - response = await a.restart(timeout=0.1) + await a.restart(timeout=0.1) out = logger.getvalue() assert "timed out" in out.lower() @@ -846,23 +847,11 @@ def teardown(self, nanny): nanny._plugin_registered = False -class SlowNanny(Nanny): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.in_instantiate = asyncio.Event() - self.wait_instantiate = asyncio.Event() - - async def instantiate(self): - self.in_instantiate.set() - await self.wait_instantiate.wait() - return await super().instantiate() - - @pytest.mark.parametrize("restart", [True, False]) @gen_cluster(client=True, nthreads=[]) async def test_nanny_plugin_register_during_start_success(c, s, restart): plugin = DummyNannyPlugin("foo", restart=restart) - n = SlowNanny(s.address) + n = BlockedStartNanny(s.address) assert not hasattr(n, "_plugin_registered") start = asyncio.create_task(n.start()) try: diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 636efcd9e48..805db3cdbb3 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1090,8 +1090,8 @@ async def test_restart(c, s, a, b): futures = c.map(inc, range(20)) await wait(futures) with captured_logger("distributed.nanny") as nanny_logger: - await s.restart() - assert "Reason: scheduler-restart" in nanny_logger.getvalue() + await s.restart(stimulus_id="test123") + assert "Reason: test123" in nanny_logger.getvalue() assert not s.computations assert not s.task_prefixes @@ -1194,7 +1194,7 @@ async def test_restart_worker_rejoins_after_timeout_expired(c, s, a): """ We don't want to see an error message like: - ``Waited for 1 worker(s) to reconnect after restarting, but after 0s, only 1 have returned.`` + ``Waited for 1 worker(s) to reconnect after restarting, but after 0s, 0 have not returned.`` If a worker rejoins after our last poll for new workers, but before we raise the error, we shouldn't raise the error. @@ -1217,7 +1217,7 @@ async def remove_worker(self, *args, **kwargs): await Plugin.removed.wait() assert not s.workers - async with Worker(s.address, nthreads=1) as w: + async with Worker(s.address, nthreads=1): assert len(s.workers) == 1 Plugin.proceed.set() @@ -1231,7 +1231,8 @@ async def test_restart_no_wait_for_workers(c, s, a, b): await c.restart(timeout="1s", wait_for_workers=False) assert not s.workers - # Workers are not immediately closed because of https://github.com/dask/distributed/issues/6390 + # Workers are not immediately closed because of + # https://github.com/dask/distributed/issues/6390 # (the message is still waiting in the BatchedSend) await a.finished() await b.finished() @@ -1268,7 +1269,7 @@ async def test_restart_heartbeat_before_closing(c, s, n): https://github.com/dask/distributed/issues/6494 """ prev_workers = dict(s.workers) - restart_task = asyncio.create_task(s.restart()) + restart_task = asyncio.create_task(s.restart(stimulus_id="test")) await n.kill_called.wait() await asyncio.sleep(0.5) # significantly longer than the heartbeat interval @@ -2573,7 +2574,7 @@ async def f(dask_worker): assert s.bandwidth_workers - await s.restart() + await s.restart(stimulus_id="test") assert not s.bandwidth_workers diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 9c167d04128..014884a17bf 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -2396,6 +2396,30 @@ def freeze_batched_send(bcomm: BatchedSend) -> Iterator[LockedComm]: bcomm.comm = orig_comm +class BlockedStartNanny(Nanny): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.in_instantiate = asyncio.Event() + self.wait_instantiate = asyncio.Event() + + async def instantiate(self): + self.in_instantiate.set() + await self.wait_instantiate.wait() + return await super().instantiate() + + +class BlockedKillNanny(Nanny): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.in_kill = asyncio.Event() + self.wait_kill = asyncio.Event() + + async def kill(self, **kwargs): + self.in_kill.set() + await self.wait_kill.wait() + return await super().kill(**kwargs) + + async def wait_for_state( key: Key, state: str | Collection[str],