Restart workers when worker-ttl expires
crusaderky committed Mar 5, 2024
1 parent 7680b85 commit e312553
Showing 5 changed files with 118 additions and 47 deletions.
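In short: when a worker stops heartbeating for longer than the configured worker-ttl, the scheduler now restarts it through its nanny (when it has one) instead of merely removing it. A minimal, illustrative sketch of the configuration involved; the TTL value and cluster shape below are assumptions, not part of this commit:

```python
import dask
from dask.distributed import Client, LocalCluster

# Illustrative TTL; per the scheduler change below, the effective TTL is
# clamped to at least 10x the heartbeat interval, so very small values are
# effectively ignored.
with dask.config.set({"distributed.scheduler.worker-ttl": "30 seconds"}):
    cluster = LocalCluster(n_workers=2, processes=True)  # workers run under nannies
    client = Client(cluster)
```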
13 changes: 5 additions & 8 deletions distributed/nanny.py
@@ -390,17 +390,14 @@ async def start_unsafe(self):

return self

-async def kill(self, timeout: float = 2, reason: str = "nanny-kill") -> None:
+async def kill(self, timeout: float = 5, reason: str = "nanny-kill") -> None:
"""Kill the local worker process
Blocks until both the process is down and the scheduler is properly
informed
"""
-if self.process is None:
-    return
-
-deadline = time() + timeout
-await self.process.kill(reason=reason, timeout=0.8 * (deadline - time()))
+if self.process is not None:
+    await self.process.kill(reason=reason, timeout=0.8 * timeout)

async def instantiate(self) -> Status:
"""Start a local worker process
@@ -822,7 +819,7 @@ def mark_stopped(self):

async def kill(
self,
-timeout: float = 2,
+timeout: float = 5,
executor_wait: bool = True,
reason: str = "workerprocess-kill",
) -> None:
@@ -876,7 +873,7 @@ async def kill(
pass

logger.warning(
f"Worker process still alive after {wait_timeout} seconds, killing"
f"Worker process still alive after {wait_timeout:.1f} seconds, killing"
)
await process.kill()
await process.join(max(0, deadline - time()))
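A rough sketch of the timeout budget implied by the new 5-second default above (illustrative arithmetic only, not distributed code; the actual escalation to a hard kill lives in WorkerProcess.kill, partially shown above):

```python
# Nanny.kill() now defaults to a 5s budget (previously 2s) and hands 80% of it
# to WorkerProcess.kill() for a graceful shutdown, leaving the remainder as
# headroom before the process is joined or hard-killed.
timeout = 5.0
graceful = 0.8 * timeout       # 4.0s passed to WorkerProcess.kill()
headroom = timeout - graceful  # 1.0s left over
print(f"graceful={graceful:.1f}s headroom={headroom:.1f}s")
```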
33 changes: 23 additions & 10 deletions distributed/scheduler.py
@@ -545,7 +545,7 @@ def __init__(
self._memory_unmanaged_old = 0
self._memory_unmanaged_history = deque()
self.metrics = {}
-self.last_seen = 0
+self.last_seen = time()
self.time_delay = 0
self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
self.actors = set()
@@ -6321,7 +6321,10 @@ async def restart_workers(
# FIXME does not raise if the process fails to shut down,
# see https://github.com/dask/distributed/pull/6427/files#r894917424
# NOTE: Nanny will automatically restart worker process when it's killed
-nanny.kill(reason=stimulus_id, timeout=timeout),
+# NOTE: Don't propagate timeout to kill(): we don't want to
+# spend (.8*.8)=64% of our end-to-end timeout waiting for a hung
+# process to restart.
+nanny.kill(reason=stimulus_id),
timeout,
)
for nanny in nannies
@@ -8391,19 +8394,29 @@ async def get_worker_monitor_info(self, recent=False, starts=None):
# Cleanup #
###########

-async def check_worker_ttl(self):
+@log_errors
+async def check_worker_ttl(self) -> None:
now = time()
stimulus_id = f"check-worker-ttl-{now}"
+assert self.worker_ttl
+ttl = max(self.worker_ttl, 10 * heartbeat_interval(len(self.workers)))
+to_restart = []
+
for ws in self.workers.values():
-if (ws.last_seen < now - self.worker_ttl) and (
-    ws.last_seen < now - 10 * heartbeat_interval(len(self.workers))
-):
+last_seen = now - ws.last_seen
+if last_seen > ttl:
+    to_restart.append(ws.address)
logger.warning(
"Worker failed to heartbeat within %s seconds. Closing: %s",
self.worker_ttl,
ws,
f"Worker failed to heartbeat for {last_seen:.0f}s; "
f"{'attempting restart' if ws.nanny else 'removing'}: {ws}"
)
-await self.remove_worker(address=ws.address, stimulus_id=stimulus_id)
+
+if to_restart:
+    await self.restart_workers(
+        to_restart,
+        wait_for_workers=False,
+        stimulus_id=stimulus_id,
+    )

def check_idle(self) -> float | None:
if self.status in (Status.closing, Status.closed):
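To summarize the new policy above: each worker's last_seen is compared against an effective TTL, and stale workers are collected and handed to restart_workers() instead of being removed outright. A minimal sketch; the helper below is hypothetical, and only heartbeat_interval is distributed's own (imported exactly as the old test did):

```python
from distributed.scheduler import heartbeat_interval


def effective_worker_ttl(worker_ttl: float, n_workers: int) -> float:
    """Hypothetical helper mirroring the check above: a configured TTL below
    10x the heartbeat interval is ignored, so one late heartbeat alone never
    triggers a restart."""
    return max(worker_ttl, 10 * heartbeat_interval(n_workers))


# e.g. worker-ttl="500ms" with 2 workers, whose heartbeat interval is 0.5s at
# the time of writing: the effective TTL becomes 5 seconds.
assert effective_worker_ttl(0.5, 2) == 5.0
```

Workers that exceed this window go to restart_workers(wait_for_workers=False), which, judging by the log message above, restarts through the nanny when one is present and removes the worker otherwise. Per the NOTE above, kill() is invoked with its own default timeout rather than the end-to-end restart timeout, so a hung process cannot consume 0.8 x 0.8 = 64% of the overall budget.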
1 change: 0 additions & 1 deletion distributed/tests/test_active_memory_manager.py
@@ -1084,7 +1084,6 @@ async def test_RetireWorker_new_keys_arrive_after_all_keys_moved_away(c, s, a, b
@gen_cluster(
client=True,
config={
"distributed.scheduler.worker-ttl": "500ms",
"distributed.scheduler.active-memory-manager.start": True,
"distributed.scheduler.active-memory-manager.interval": 0.05,
"distributed.scheduler.active-memory-manager.measure": "managed",
87 changes: 81 additions & 6 deletions distributed/tests/test_failed_workers.py
@@ -14,9 +14,10 @@
from dask import delayed
from dask.utils import parse_bytes

-from distributed import Client, Nanny, profile, wait
+from distributed import Client, KilledWorker, Nanny, get_worker, profile, wait
from distributed.comm import CommClosedError
from distributed.compatibility import MACOS
+from distributed.core import Status
from distributed.metrics import time
from distributed.utils import CancelledError, sync
from distributed.utils_test import (
@@ -450,10 +451,10 @@ async def test_restart_timeout_on_long_running_task(c, s, a):


@pytest.mark.slow
-@gen_cluster(client=True, scheduler_kwargs={"worker_ttl": "500ms"})
+@gen_cluster(client=True, config={"distributed.scheduler.worker-ttl": "500ms"})
async def test_worker_time_to_live(c, s, a, b):
-from distributed.scheduler import heartbeat_interval

+# Note that this value is ignored because it is less than 10x heartbeat_interval
assert s.worker_ttl == 0.5
assert set(s.workers) == {a.address, b.address}

a.periodic_callbacks["heartbeat"].stop()
@@ -465,10 +466,84 @@ async def test_worker_time_to_live(c, s, a, b):

# Worker removal is triggered after 10 * heartbeat
# This is 10 * 0.5s at the moment of writing.
-interval = 10 * heartbeat_interval(len(s.workers))
# Currently observing an extra 0.3~0.6s on top of the interval.
# Adding some padding to prevent flakiness.
-assert time() - start < interval + 2.0
+assert time() - start < 7


@pytest.mark.slow
@pytest.mark.parametrize("block_evloop", [False, True])
@gen_cluster(
client=True,
Worker=Nanny,
nthreads=[("", 1)],
scheduler_kwargs={"worker_ttl": "500ms", "allowed_failures": 0},
)
async def test_worker_ttl_restarts_worker(c, s, a, block_evloop):
"""If the event loop of a worker becomes completely unresponsive, the scheduler will
restart it through the nanny.
"""
ws = s.workers[a.worker_address]

async def f():
w = get_worker()
w.periodic_callbacks["heartbeat"].stop()
if block_evloop:
sleep(9999) # Block event loop indefinitely
else:
await asyncio.sleep(9999)

fut = c.submit(f, key="x")

while not s.workers or (
(new_ws := next(iter(s.workers.values()))) is ws
or new_ws.status != Status.running
):
await asyncio.sleep(0.01)

if block_evloop:
# The nanny killed the worker with SIGKILL.
# The restart has increased the suspicious count.
with pytest.raises(KilledWorker):
await fut
assert s.tasks["x"].state == "erred"
assert s.tasks["x"].suspicious == 1
else:
# The nanny sent to the WorkerProcess a {op: stop} through IPC, which in turn
# successfully invoked Worker.close(nanny=False).
# This behaviour makes sense as the worker-ttl timeout was most likely caused
# by a failure in networking, rather than a hung process.
assert s.tasks["x"].state == "processing"
assert s.tasks["x"].suspicious == 0


@pytest.mark.slow
@gen_cluster(
client=True,
Worker=Nanny,
nthreads=[("", 2)],
scheduler_kwargs={"allowed_failures": 0},
)
async def test_restart_hung_worker(c, s, a):
"""Test restart_workers() to restart a worker whose event loop has become completely
unresponsive.
"""
ws = s.workers[a.worker_address]

async def f():
w = get_worker()
w.periodic_callbacks["heartbeat"].stop()
sleep(9999) # Block event loop indefinitely

fut = c.submit(f)
# Wait for worker to hang
with pytest.raises(asyncio.TimeoutError):
while True:
await wait(c.submit(inc, 1, pure=False), timeout=0.2)

await c.restart_workers([a.worker_address])
assert len(s.workers) == 1
assert next(iter(s.workers.values())) is not ws


@gen_cluster(client=True, nthreads=[("", 1)])
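For reference, the user-facing way to trigger the behaviour exercised by the new tests, assuming a cluster whose workers run under nannies (addresses below are hypothetical):

```python
from dask.distributed import Client

client = Client("tcp://scheduler-host:8786")  # hypothetical scheduler address

# Explicitly restart a worker whose event loop appears hung; the scheduler asks
# the worker's nanny to kill and respawn the process, much as
# test_restart_hung_worker does above via restart_workers.
client.restart_workers(["tcp://10.0.0.5:34567"])  # hypothetical worker address
```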
31 changes: 9 additions & 22 deletions distributed/tests/test_scheduler.py
@@ -53,6 +53,7 @@
NO_AMM,
BlockedGatherDep,
BlockedGetData,
+BlockedKillNanny,
BrokenComm,
NoSchedulerDelayWorker,
assert_story,
@@ -1132,22 +1133,8 @@ async def test_restart_waits_for_new_workers(c, s, *workers):
assert set(s.workers.values()).isdisjoint(original_workers.values())


-class SlowKillNanny(Nanny):
-    def __init__(self, *args, **kwargs):
-        self.kill_proceed = asyncio.Event()
-        self.kill_called = asyncio.Event()
-        super().__init__(*args, **kwargs)
-
-    async def kill(self, *, timeout, reason=None):
-        self.kill_called.set()
-        print("kill called")
-        await wait_for(self.kill_proceed.wait(), timeout)
-        print("kill proceed")
-        return await super().kill(timeout=timeout, reason=reason)


@pytest.mark.slow
-@gen_cluster(client=True, Worker=SlowKillNanny, nthreads=[("", 1)] * 2)
+@gen_cluster(client=True, Worker=BlockedKillNanny, nthreads=[("", 1)] * 2)
async def test_restart_nanny_timeout_exceeded(c, s, a, b):
try:
f = c.submit(div, 1, 0)
@@ -1162,8 +1149,8 @@ async def test_restart_nanny_timeout_exceeded(c, s, a, b):
TimeoutError, match=r"2/2 nanny worker\(s\) did not shut down within 1s"
):
await c.restart(timeout="1s")
-assert a.kill_called.is_set()
-assert b.kill_called.is_set()
+assert a.in_kill.is_set()
+assert b.in_kill.is_set()

assert not s.workers
assert not s.erred_tasks
@@ -1175,8 +1162,8 @@ async def test_restart_nanny_timeout_exceeded(c, s, a, b):
assert f.status == "cancelled"
assert fr.status == "cancelled"
finally:
-a.kill_proceed.set()
-b.kill_proceed.set()
+a.wait_kill.set()
+b.wait_kill.set()


@gen_cluster(client=True, nthreads=[("", 1)] * 2)
@@ -1260,7 +1247,7 @@ async def test_restart_some_nannies_some_not(c, s, a, b):
@gen_cluster(
client=True,
nthreads=[("", 1)],
-Worker=SlowKillNanny,
+Worker=BlockedKillNanny,
worker_kwargs={"heartbeat_interval": "1ms"},
)
async def test_restart_heartbeat_before_closing(c, s, n):
@@ -1271,13 +1258,13 @@ async def test_restart_heartbeat_before_closing(c, s, n):
prev_workers = dict(s.workers)
restart_task = asyncio.create_task(s.restart(stimulus_id="test"))

-await n.kill_called.wait()
+await n.in_kill.wait()
await asyncio.sleep(0.5) # significantly longer than the heartbeat interval

# WorkerState should not be removed yet, because the worker hasn't been told to close
assert s.workers

-n.kill_proceed.set()
+n.wait_kill.set()
# Wait until the worker has left (possibly until it's come back too)
while s.workers == prev_workers:
await asyncio.sleep(0.01)
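The tests above now use BlockedKillNanny from distributed.utils_test in place of the locally defined SlowKillNanny. Judging from the attributes referenced (in_kill / wait_kill), it plays the same role as the deleted class; the sketch below is a hypothetical equivalent mirroring the removed code, not the actual utility:

```python
import asyncio

from distributed import Nanny


class BlockedKillNannySketch(Nanny):
    """Hypothetical stand-in: kill() signals in_kill, then waits for the test
    to set wait_kill (or for the timeout to elapse) before actually killing."""

    def __init__(self, *args, **kwargs):
        self.in_kill = asyncio.Event()
        self.wait_kill = asyncio.Event()
        super().__init__(*args, **kwargs)

    async def kill(self, timeout: float = 5, reason: str = "nanny-kill") -> None:
        self.in_kill.set()
        await asyncio.wait_for(self.wait_kill.wait(), timeout)
        await super().kill(timeout=timeout, reason=reason)
```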
