Skip to content

Commit

Permalink
Refactor restart() and restart_workers() (#8550)
Browse files (browse the repository at this point in the history)
  • Loading branch information
crusaderky authored Mar 21, 2024
1 parent 8927bfd commit 84213ac
Show file tree
Hide file tree
Showing 9 changed files with 301 additions and 178 deletions.
80 changes: 43 additions & 37 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from dask.core import flatten, validate_key
from dask.highlevelgraph import HighLevelGraph
from dask.optimization import SubgraphCallable
from dask.typing import no_default
from dask.typing import NoDefault, no_default
from dask.utils import (
apply,
ensure_dict,
Expand All @@ -49,7 +49,7 @@
)
from dask.widgets import get_template

from distributed.core import ErrorMessage, OKMessage
from distributed.core import OKMessage
from distributed.protocol.serialize import _is_dumpable
from distributed.utils import Deadline, wait_for

Expand Down Expand Up @@ -859,8 +859,9 @@ def __init__(
):
if timeout is no_default:
timeout = dask.config.get("distributed.comm.timeouts.connect")
if timeout is not None:
timeout = parse_timedelta(timeout, "s")
timeout = parse_timedelta(timeout, "s")
if timeout is None:
raise ValueError("None is an invalid value for global client timeout")
self._timeout = timeout

self.futures = dict()
Expand Down Expand Up @@ -1253,8 +1254,7 @@ async def _start(self, timeout=no_default, **kwargs):

if timeout is no_default:
timeout = self._timeout
if timeout is not None:
timeout = parse_timedelta(timeout, "s")
timeout = parse_timedelta(timeout, "s")

address = self._start_arg
if self.cluster is not None:
Expand Down Expand Up @@ -3596,16 +3596,24 @@ def persist(
else:
return result

async def _restart(self, timeout=no_default, wait_for_workers=True):
async def _restart(
self, timeout: str | int | float | NoDefault, wait_for_workers: bool
) -> None:
if timeout is no_default:
timeout = self._timeout * 4
if timeout is not None:
timeout = parse_timedelta(timeout, "s")
timeout = parse_timedelta(cast("str|int|float", timeout), "s")

await self.scheduler.restart(timeout=timeout, wait_for_workers=wait_for_workers)
return self
await self.scheduler.restart(
timeout=timeout,
wait_for_workers=wait_for_workers,
stimulus_id=f"client-restart-{time()}",
)

def restart(self, timeout=no_default, wait_for_workers=True):
def restart(
self,
timeout: str | int | float | NoDefault = no_default,
wait_for_workers: bool = True,
):
"""
Restart all workers. Reset local state. Optionally wait for workers to return.
Expand Down Expand Up @@ -3642,46 +3650,43 @@ def restart(self, timeout=no_default, wait_for_workers=True):
async def _restart_workers(
self,
workers: list[str],
timeout: int | float | None = None,
raise_for_error: bool = True,
) -> dict[str, str | ErrorMessage]:
timeout: str | int | float | NoDefault,
raise_for_error: bool,
) -> dict[str, Literal["OK", "removed", "timed out"]]:
if timeout is no_default:
timeout = self._timeout * 4
timeout = parse_timedelta(cast("str|int|float", timeout), "s")

info = self.scheduler_info()
name_to_addr = {meta["name"]: addr for addr, meta in info["workers"].items()}
worker_addrs = [name_to_addr.get(w, w) for w in workers]

restart_out: dict[str, str | ErrorMessage] = await self.scheduler.broadcast(
msg={"op": "restart", "timeout": timeout},
out: dict[
str, Literal["OK", "removed", "timed out"]
] = await self.scheduler.restart_workers(
workers=worker_addrs,
nanny=True,
timeout=timeout,
on_error="raise" if raise_for_error else "return",
stimulus_id=f"client-restart-workers-{time()}",
)

# Map keys back to original `workers` input names/addresses
results = {w: restart_out[w_addr] for w, w_addr in zip(workers, worker_addrs)}

timeout_workers = [w for w, status in results.items() if status == "timed out"]
if timeout_workers and raise_for_error:
raise TimeoutError(
f"The following workers failed to restart with {timeout} seconds: {timeout_workers}"
)

errored: list[ErrorMessage] = [m for m in results.values() if "exception" in m] # type: ignore
if errored and raise_for_error:
raise pickle.loads(errored[0]["exception"]) # type: ignore
return results
out = {w: out[w_addr] for w, w_addr in zip(workers, worker_addrs)}
if raise_for_error:
assert all(v == "OK" for v in out.values()), out
return out

def restart_workers(
self,
workers: list[str],
timeout: int | float | None = None,
timeout: str | int | float | NoDefault = no_default,
raise_for_error: bool = True,
) -> dict[str, str]:
):
"""Restart a specified set of workers
.. note::
Only workers being monitored by a :class:`distributed.Nanny` can be restarted.
See ``Nanny.restart`` for more details.
See ``Nanny.restart`` for more details.
Parameters
----------
Expand All @@ -3696,7 +3701,7 @@ def restart_workers(
Returns
-------
dict[str, str]
dict[str, "OK" | "removed" | "timed out"]
Mapping of worker and restart status, the keys will match the original
values passed in via ``workers``.
Expand Down Expand Up @@ -3730,7 +3735,8 @@ def restart_workers(
for worker, meta in info["workers"].items():
if (worker in workers or meta["name"] in workers) and meta["nanny"] is None:
raise ValueError(
f"Restarting workers requires a nanny to be used. Worker {worker} has type {info['workers'][worker]['type']}."
f"Restarting workers requires a nanny to be used. Worker "
f"{worker} has type {info['workers'][worker]['type']}."
)
return self.sync(
self._restart_workers,
Expand Down
4 changes: 2 additions & 2 deletions distributed/diagnostics/tests/test_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def f(x):
await wait([future])
assert p.state["memory"] == {"f": {future.key}}

await c._restart()
await c.restart()

for coll in [p.all] + list(p.state.values()):
assert not coll
Expand Down Expand Up @@ -262,7 +262,7 @@ async def test_group_timing(c, s, a, b):
]
)

await s.restart()
await s.restart(stimulus_id="test")
assert len(p.time) == 2
assert len(p.nthreads) == 2
assert len(p.compute) == 0
1 change: 1 addition & 0 deletions distributed/nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,7 @@ async def kill(
assert self.status in (
Status.running,
Status.failed, # process failed to start, but hasn't been joined yet
Status.closing_gracefully,
), self.status
self.status = Status.stopping
logger.info("Nanny asking worker to close. Reason: %s", reason)
Expand Down
Loading

0 comments on commit 84213ac

Please sign in to comment.