From 587d472cef9286b5ba8159d7e4f70ebf8ba8e95f Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Mon, 15 Apr 2024 18:02:48 +0100
Subject: [PATCH] Restart workers when worker-ttl expires (#8538)

When a worker fails to heartbeat for longer than worker-ttl, restart it
through its nanny (when it has one) instead of permanently removing it from
the cluster. Also raise the default kill timeout from 2s to 5s and stop
propagating the restart_workers() timeout to Nanny.kill(), so that a hung
process no longer consumes most of the end-to-end timeout.

---
 distributed/nanny.py                          | 13 ++-
 distributed/scheduler.py                      | 33 ++++---
 .../tests/test_active_memory_manager.py       |  1 -
 distributed/tests/test_failed_workers.py      | 87 +++++++++++++++++--
 4 files changed, 109 insertions(+), 25 deletions(-)

diff --git a/distributed/nanny.py b/distributed/nanny.py
index 99644e9292..52e4ad5b36 100644
--- a/distributed/nanny.py
+++ b/distributed/nanny.py
@@ -390,17 +390,14 @@ async def start_unsafe(self):
 
         return self
 
-    async def kill(self, timeout: float = 2, reason: str = "nanny-kill") -> None:
+    async def kill(self, timeout: float = 5, reason: str = "nanny-kill") -> None:
         """Kill the local worker process
 
         Blocks until both the process is down and the scheduler is properly
         informed
         """
-        if self.process is None:
-            return
-
-        deadline = time() + timeout
-        await self.process.kill(reason=reason, timeout=0.8 * (deadline - time()))
+        if self.process is not None:
+            await self.process.kill(reason=reason, timeout=timeout)
 
     async def instantiate(self) -> Status:
         """Start a local worker process
@@ -822,7 +819,7 @@ def mark_stopped(self):
 
     async def kill(
         self,
-        timeout: float = 2,
+        timeout: float = 5,
         executor_wait: bool = True,
         reason: str = "workerprocess-kill",
     ) -> None:
@@ -876,7 +873,7 @@ async def kill(
                 pass
 
             logger.warning(
-                f"Worker process still alive after {wait_timeout} seconds, killing"
+                f"Worker process still alive after {wait_timeout:.1f} seconds, killing"
             )
             await process.kill()
             await process.join(max(0, deadline - time()))
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index dc87640677..5740ac4414 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -547,7 +547,7 @@ def __init__(
         self._memory_unmanaged_old = 0
         self._memory_unmanaged_history = deque()
         self.metrics = {}
-        self.last_seen = 0
+        self.last_seen = time()
         self.time_delay = 0
         self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
         self.actors = set()
@@ -6335,7 +6335,10 @@ async def restart_workers(
                         # FIXME does not raise if the process fails to shut down,
                         # see https://github.com/dask/distributed/pull/6427/files#r894917424
                         # NOTE: Nanny will automatically restart worker process when it's killed
-                        nanny.kill(reason=stimulus_id, timeout=timeout),
+                        # NOTE: Don't propagate timeout to kill(): we don't want to
+                        # spend (.8*.8)=64% of our end-to-end timeout waiting for a hung
+                        # process to restart.
+                        nanny.kill(reason=stimulus_id),
                         timeout,
                     )
                     for nanny in nannies
@@ -8406,19 +8409,29 @@ async def get_worker_monitor_info(self, recent=False, starts=None):
     # Cleanup #
     ###########
 
-    async def check_worker_ttl(self):
+    @log_errors
+    async def check_worker_ttl(self) -> None:
         now = time()
         stimulus_id = f"check-worker-ttl-{now}"
+        assert self.worker_ttl
+        ttl = max(self.worker_ttl, 10 * heartbeat_interval(len(self.workers)))
+        to_restart = []
+
         for ws in self.workers.values():
-            if (ws.last_seen < now - self.worker_ttl) and (
-                ws.last_seen < now - 10 * heartbeat_interval(len(self.workers))
-            ):
+            last_seen = now - ws.last_seen
+            if last_seen > ttl:
+                to_restart.append(ws.address)
                 logger.warning(
-                    "Worker failed to heartbeat within %s seconds. Closing: %s",
-                    self.worker_ttl,
-                    ws,
+                    f"Worker failed to heartbeat for {last_seen:.0f}s; "
+                    f"{'attempting restart' if ws.nanny else 'removing'}: {ws}"
                 )
-            await self.remove_worker(address=ws.address, stimulus_id=stimulus_id)
+
+        if to_restart:
+            await self.restart_workers(
+                to_restart,
+                wait_for_workers=False,
+                stimulus_id=stimulus_id,
+            )
 
     def check_idle(self) -> float | None:
         if self.status in (Status.closing, Status.closed):
diff --git a/distributed/tests/test_active_memory_manager.py b/distributed/tests/test_active_memory_manager.py
index 895e18cd34..eb13356836 100644
--- a/distributed/tests/test_active_memory_manager.py
+++ b/distributed/tests/test_active_memory_manager.py
@@ -1084,7 +1084,6 @@ async def test_RetireWorker_new_keys_arrive_after_all_keys_moved_away(c, s, a, b
 @gen_cluster(
     client=True,
     config={
-        "distributed.scheduler.worker-ttl": "500ms",
         "distributed.scheduler.active-memory-manager.start": True,
         "distributed.scheduler.active-memory-manager.interval": 0.05,
         "distributed.scheduler.active-memory-manager.measure": "managed",
diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py
index 7c7c667e4b..b62e40b8e4 100644
--- a/distributed/tests/test_failed_workers.py
+++ b/distributed/tests/test_failed_workers.py
@@ -14,9 +14,10 @@
 from dask import delayed
 from dask.utils import parse_bytes
 
-from distributed import Client, Nanny, profile, wait
+from distributed import Client, KilledWorker, Nanny, get_worker, profile, wait
 from distributed.comm import CommClosedError
 from distributed.compatibility import MACOS
+from distributed.core import Status
 from distributed.metrics import time
 from distributed.utils import CancelledError, sync
 from distributed.utils_test import (
@@ -450,10 +451,10 @@ async def test_restart_timeout_on_long_running_task(c, s, a):
 
 @pytest.mark.slow
-@gen_cluster(client=True, scheduler_kwargs={"worker_ttl": "500ms"})
+@gen_cluster(client=True, config={"distributed.scheduler.worker-ttl": "500ms"})
 async def test_worker_time_to_live(c, s, a, b):
-    from distributed.scheduler import heartbeat_interval
-
+    # Note that this value is ignored because it is less than 10x heartbeat_interval
+    assert s.worker_ttl == 0.5
     assert set(s.workers) == {a.address, b.address}
 
     a.periodic_callbacks["heartbeat"].stop()
@@ -465,10 +466,84 @@ async def test_worker_time_to_live(c, s, a, b):
 
     # Worker removal is triggered after 10 * heartbeat
     # This is 10 * 0.5s at the moment of writing.
-    interval = 10 * heartbeat_interval(len(s.workers))
     # Currently observing an extra 0.3~0.6s on top of the interval.
     # Adding some padding to prevent flakiness.
-    assert time() - start < interval + 2.0
+    assert time() - start < 7
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("block_evloop", [False, True])
+@gen_cluster(
+    client=True,
+    Worker=Nanny,
+    nthreads=[("", 1)],
+    scheduler_kwargs={"worker_ttl": "500ms", "allowed_failures": 0},
+)
+async def test_worker_ttl_restarts_worker(c, s, a, block_evloop):
+    """If the event loop of a worker becomes completely unresponsive, the scheduler will
+    restart it through the nanny.
+    """
+    ws = s.workers[a.worker_address]
+
+    async def f():
+        w = get_worker()
+        w.periodic_callbacks["heartbeat"].stop()
+        if block_evloop:
+            sleep(9999)  # Block event loop indefinitely
+        else:
+            await asyncio.sleep(9999)
+
+    fut = c.submit(f, key="x")
+
+    while not s.workers or (
+        (new_ws := next(iter(s.workers.values()))) is ws
+        or new_ws.status != Status.running
+    ):
+        await asyncio.sleep(0.01)
+
+    if block_evloop:
+        # The nanny killed the worker with SIGKILL.
+        # The restart has increased the suspicious count.
+        with pytest.raises(KilledWorker):
+            await fut
+        assert s.tasks["x"].state == "erred"
+        assert s.tasks["x"].suspicious == 1
+    else:
+        # The nanny sent a {op: stop} message to the WorkerProcess over IPC, which
+        # in turn successfully invoked Worker.close(nanny=False).
+        # This behaviour makes sense, as the worker-ttl timeout was most likely
+        # caused by a networking failure rather than by a hung process.
+        assert s.tasks["x"].state == "processing"
+        assert s.tasks["x"].suspicious == 0
+
+
+@pytest.mark.slow
+@gen_cluster(
+    client=True,
+    Worker=Nanny,
+    nthreads=[("", 2)],
+    scheduler_kwargs={"allowed_failures": 0},
+)
+async def test_restart_hung_worker(c, s, a):
+    """Test that restart_workers() can restart a worker whose event loop has become
+    completely unresponsive.
+    """
+    ws = s.workers[a.worker_address]
+
+    async def f():
+        w = get_worker()
+        w.periodic_callbacks["heartbeat"].stop()
+        sleep(9999)  # Block event loop indefinitely
+
+    fut = c.submit(f)
+    # Wait for the worker to hang
+    with pytest.raises(asyncio.TimeoutError):
+        while True:
+            await wait(c.submit(inc, 1, pure=False), timeout=0.2)
+
+    await c.restart_workers([a.worker_address])
+    assert len(s.workers) == 1
+    assert next(iter(s.workers.values())) is not ws


 @gen_cluster(client=True, nthreads=[("", 1)])
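
A minimal end-to-end sketch (not part of the commit) of the behaviour this patch
enables. It uses only public API that the patch itself relies on -- dask config,
LocalCluster (which runs workers under nannies by default), and get_worker() --
but the script as a whole, including the hang_event_loop helper, the 5s TTL, and
the 15s wait, is an illustrative assumption rather than code from dask/distributed:

import time

import dask
from distributed import Client, LocalCluster, get_worker


async def hang_event_loop():
    # Async tasks run directly on the worker's event loop, so a synchronous
    # sleep blocks it outright; stopping the heartbeat first mirrors the
    # trick used by the new tests in this patch.
    get_worker().periodic_callbacks["heartbeat"].stop()
    time.sleep(9999)


if __name__ == "__main__":
    # worker-ttl must be set before the scheduler starts. The effective TTL
    # is max(worker_ttl, 10 * heartbeat_interval), i.e. ~5s for one worker.
    with dask.config.set({"distributed.scheduler.worker-ttl": "5s"}):
        with LocalCluster(n_workers=1, threads_per_worker=1) as cluster:
            with Client(cluster) as client:
                before = set(cluster.scheduler_info["workers"])
                fut = client.submit(hang_event_loop)  # keep a reference alive
                time.sleep(15)  # leave check_worker_ttl() time to fire
                after = set(cluster.scheduler_info["workers"])
                # The hung process was killed by its nanny and replaced by a
                # fresh worker, so a worker is still present, at a new address.
                print("before:", before)
                print("after: ", after)

Before this patch, worker-ttl expiry only called remove_worker(), with nothing
prompting the nanny to start a fresh process, so the same script would end with
the hung worker gone and no replacement in the cluster.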