Skip to content

Commit

Permalink
Add fail_hard decorator for worker methods (#6210)
Browse files Browse the repository at this point in the history
  • Loading branch information
mrocklin authored Apr 29, 2022
1 parent a39bd15 commit be45ba2
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 28 deletions.
1 change: 0 additions & 1 deletion distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,6 @@ async def handle_stream(self, comm, extra=None, every_cycle=()):
func()

except OSError:
# FIXME: This is silently ignored, is this intentional?
pass
except Exception as e:
logger.exception(e)
Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_stress.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ async def test_no_delay_during_large_transfer(c, s, w):
async def test_chaos_rechunk(c, s, *workers):
s.allowed_failures = 10000

plugin = KillWorker(delay="4 s", mode="graceful")
plugin = KillWorker(delay="4 s", mode="sys.exit")

await c.register_worker_plugin(plugin, name="kill")

Expand Down
14 changes: 14 additions & 0 deletions distributed/tests/test_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,3 +679,17 @@ async def test_log_invalid_worker_task_states(c, s, a):

assert "released" in out + err
assert "task-name" in out + err


def test_worker_fail_hard(capsys):
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
async def test_fail_hard(c, s, a):
with pytest.raises(Exception):
await a.gather_dep(
worker="abcd", to_gather=["x"], total_nbytes=0, stimulus_id="foo"
)

with pytest.raises(Exception) as info:
test_fail_hard()

assert "abcd" in str(info.value)
26 changes: 25 additions & 1 deletion distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,21 @@
from distributed.comm.tcp import TCP
from distributed.compatibility import WINDOWS
from distributed.config import initialize_logging
from distributed.core import CommClosedError, ConnectionPool, Status, connect, rpc
from distributed.core import (
CommClosedError,
ConnectionPool,
Status,
clean_exception,
connect,
rpc,
)
from distributed.deploy import SpecCluster
from distributed.diagnostics.plugin import WorkerPlugin
from distributed.metrics import time
from distributed.nanny import Nanny
from distributed.node import ServerNode
from distributed.proctitle import enable_proctitle_on_children
from distributed.protocol import deserialize
from distributed.security import Security
from distributed.utils import (
DequeHandler,
Expand Down Expand Up @@ -878,6 +886,7 @@ async def start_cluster(
await s.close(fast=True)
check_invalid_worker_transitions(s)
check_invalid_task_states(s)
check_worker_fail_hard(s)
raise TimeoutError("Cluster creation timeout")
return s, workers

Expand Down Expand Up @@ -909,6 +918,20 @@ def check_invalid_task_states(s: Scheduler) -> None:
raise ValueError("Invalid worker task state")


def check_worker_fail_hard(s: Scheduler) -> None:
if not s.events.get("worker-fail-hard"):
return

for timestamp, msg in s.events["worker-fail-hard"]:
msg = msg.copy()
worker = msg.pop("worker")
msg["exception"] = deserialize(msg["exception"].header, msg["exception"].frames)
msg["traceback"] = deserialize(msg["traceback"].header, msg["traceback"].frames)
print("Failed worker", worker)
typ, exc, tb = clean_exception(**msg)
raise exc.with_traceback(tb)


async def end_cluster(s, workers):
logger.debug("Closing out test cluster")

Expand All @@ -921,6 +944,7 @@ async def end_worker(w):
s.stop()
check_invalid_worker_transitions(s)
check_invalid_task_states(s)
check_worker_fail_hard(s)


def gen_cluster(
Expand Down
119 changes: 94 additions & 25 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,72 @@
DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {}


def fail_hard(method):
"""
Decorator to close the worker if this method encounters an exception.
"""
if iscoroutinefunction(method):

@functools.wraps(method)
async def wrapper(self, *args, **kwargs):
try:
return await method(self, *args, **kwargs)
except Exception as e:
if self.status not in (Status.closed, Status.closing):
self.log_event(
"worker-fail-hard",
{
**error_message(e),
"worker": self.address,
},
)
logger.exception(e)
await _force_close(self)

else:

@functools.wraps(method)
def wrapper(self, *args, **kwargs):
try:
return method(self, *args, **kwargs)
except Exception as e:
if self.status not in (Status.closed, Status.closing):
self.log_event(
"worker-fail-hard",
{
**error_message(e),
"worker": self.address,
},
)
logger.exception(e)
else:
self.loop.add_callback(_force_close, self)

return wrapper


async def _force_close(self):
"""
Used with the fail_hard decorator defined above
1. Wait for a worker to close
2. If it doesn't, log and kill the process
"""
try:
await asyncio.wait_for(self.close(nanny=False, executor_wait=False), 30)
except (Exception, BaseException): # <-- include BaseException here or not??
# Worker is in a very broken state if closing fails. We need to shut down immediately,
# to ensure things don't get even worse and this worker potentially deadlocks the cluster.
logger.critical(
"Error trying close worker in response to broken internal state. "
"Forcibly exiting worker NOW",
exc_info=True,
)
# use `os._exit` instead of `sys.exit` because of uncertainty
# around propagating `SystemExit` from asyncio callbacks
os._exit(1)


class Worker(ServerNode):
"""Worker node in a Dask distributed cluster
Expand Down Expand Up @@ -900,14 +966,15 @@ def logs(self):
return self._deque_handler.deque

def log_event(self, topic, msg):
self.loop.add_callback(
self.batched_stream.send,
{
"op": "log-event",
"topic": topic,
"msg": msg,
},
)
full_msg = {
"op": "log-event",
"topic": topic,
"msg": msg,
}
if self.thread_id == threading.get_ident():
self.batched_stream.send(full_msg)
else:
self.loop.add_callback(self.batched_stream.send, full_msg)

@property
def executing_count(self) -> int:
Expand Down Expand Up @@ -1199,22 +1266,19 @@ async def heartbeat(self):
finally:
self.heartbeat_active = False

@fail_hard
async def handle_scheduler(self, comm):
try:
await self.handle_stream(comm, every_cycle=[self.ensure_communicating])
except Exception as e:
logger.exception(e)
raise
finally:
if self.reconnect and self.status in Status.ANY_RUNNING:
logger.info("Connection to scheduler broken. Reconnecting...")
self.loop.add_callback(self.heartbeat)
else:
logger.info(
"Connection to scheduler broken. Closing without reporting. Status: %s",
self.status,
)
await self.close(report=False)
await self.handle_stream(comm, every_cycle=[self.ensure_communicating])

if self.reconnect and self.status in Status.ANY_RUNNING:
logger.info("Connection to scheduler broken. Reconnecting...")
self.loop.add_callback(self.heartbeat)
else:
logger.info(
"Connection to scheduler broken. Closing without reporting. Status: %s",
self.status,
)
await self.close(report=False)

async def upload_file(self, comm, filename=None, data=None, load=True):
out_filename = os.path.join(self.local_directory, filename)
Expand Down Expand Up @@ -1654,7 +1718,9 @@ async def get_data(
assert response == "OK", response
except OSError:
logger.exception(
"failed during get data with %s -> %s", self.address, who, exc_info=True
"failed during get data with %s -> %s",
self.address,
who,
)
comm.abort()
raise
Expand Down Expand Up @@ -2682,6 +2748,7 @@ def handle_stimulus(self, stim: StateMachineEvent) -> None:
self.transitions(recs, stimulus_id=stim.stimulus_id)
self._handle_instructions(instructions)

@fail_hard
@log_errors
def _handle_stimulus_from_task(
self, task: asyncio.Task[StateMachineEvent | None]
Expand All @@ -2695,6 +2762,7 @@ def _handle_stimulus_from_task(
if stim:
self.handle_stimulus(stim)

@fail_hard
def _handle_instructions(self, instructions: Instructions) -> None:
# TODO this method is temporary.
# See final design: https://github.com/dask/distributed/issues/5894
Expand Down Expand Up @@ -3023,6 +3091,7 @@ def _update_metrics_received_data(
self.counters["transfer-count"].add(len(data))
self.incoming_count += 1

@fail_hard
@log_errors
async def gather_dep(
self,
Expand Down Expand Up @@ -3548,6 +3617,7 @@ def _ensure_computing(self) -> RecsInstrs:

return recs, []

@fail_hard
async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | None:
if self.status in {Status.closing, Status.closed, Status.closing_gracefully}:
return None
Expand Down Expand Up @@ -4107,7 +4177,6 @@ def validate_state(self):

except Exception as e:
logger.error("Validate state failed. Closing.", exc_info=e)
self.loop.add_callback(self.close)
logger.exception(e)
if LOG_PDB:
import pdb
Expand Down

0 comments on commit be45ba2

Please sign in to comment.