Add fail_hard decorator for worker methods #6210

Merged: 18 commits, Apr 29, 2022
Changes from all commits
1 change: 0 additions & 1 deletion distributed/core.py
@@ -634,7 +634,6 @@ async def handle_stream(self, comm, extra=None, every_cycle=()):
func()

except OSError:
# FIXME: This is silently ignored, is this intentional?
pass
except Exception as e:
logger.exception(e)
2 changes: 1 addition & 1 deletion distributed/tests/test_stress.py
@@ -295,7 +295,7 @@ async def test_no_delay_during_large_transfer(c, s, w):
async def test_chaos_rechunk(c, s, *workers):
s.allowed_failures = 10000

plugin = KillWorker(delay="4 s", mode="graceful")
plugin = KillWorker(delay="4 s", mode="sys.exit")

await c.register_worker_plugin(plugin, name="kill")

14 changes: 14 additions & 0 deletions distributed/tests/test_utils_test.py
@@ -679,3 +679,17 @@ async def test_log_invalid_worker_task_states(c, s, a):

assert "released" in out + err
assert "task-name" in out + err


def test_worker_fail_hard(capsys):
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
async def test_fail_hard(c, s, a):
with pytest.raises(Exception):
await a.gather_dep(
worker="abcd", to_gather=["x"], total_nbytes=0, stimulus_id="foo"
)

with pytest.raises(Exception) as info:
test_fail_hard()

assert "abcd" in str(info.value)
26 changes: 25 additions & 1 deletion distributed/utils_test.py
@@ -50,13 +50,21 @@
from distributed.comm.tcp import TCP
from distributed.compatibility import WINDOWS
from distributed.config import initialize_logging
from distributed.core import CommClosedError, ConnectionPool, Status, connect, rpc
from distributed.core import (
CommClosedError,
ConnectionPool,
Status,
clean_exception,
connect,
rpc,
)
from distributed.deploy import SpecCluster
from distributed.diagnostics.plugin import WorkerPlugin
from distributed.metrics import time
from distributed.nanny import Nanny
from distributed.node import ServerNode
from distributed.proctitle import enable_proctitle_on_children
from distributed.protocol import deserialize
from distributed.security import Security
from distributed.utils import (
DequeHandler,
@@ -878,6 +886,7 @@ async def start_cluster(
await s.close(fast=True)
check_invalid_worker_transitions(s)
check_invalid_task_states(s)
check_worker_fail_hard(s)
raise TimeoutError("Cluster creation timeout")
return s, workers

@@ -909,6 +918,20 @@ def check_invalid_task_states(s: Scheduler) -> None:
raise ValueError("Invalid worker task state")


def check_worker_fail_hard(s: Scheduler) -> None:
if not s.events.get("worker-fail-hard"):
return

for timestamp, msg in s.events["worker-fail-hard"]:
msg = msg.copy()
worker = msg.pop("worker")
msg["exception"] = deserialize(msg["exception"].header, msg["exception"].frames)
msg["traceback"] = deserialize(msg["traceback"].header, msg["traceback"].frames)
print("Failed worker", worker)
typ, exc, tb = clean_exception(**msg)
raise exc.with_traceback(tb)
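
Not part of the diff: `check_worker_fail_hard` re-raises a worker-side exception inside the test process so the test fails loudly. Setting the serialization step aside, the essential mechanism is `exc.with_traceback(tb)`; a minimal, hypothetical sketch:

```python
import sys
import traceback

def capture():
    try:
        1 / 0
    except ZeroDivisionError:
        _, exc, tb = sys.exc_info()
        return exc, tb  # roughly what error_message() preserves (after serialization)

exc, tb = capture()
try:
    # Re-raise the captured exception somewhere else, keeping its original traceback.
    raise exc.with_traceback(tb)
except ZeroDivisionError:
    traceback.print_exc()  # the printed traceback still points into capture()
```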


async def end_cluster(s, workers):
logger.debug("Closing out test cluster")

@@ -921,6 +944,7 @@ async def end_worker(w):
s.stop()
check_invalid_worker_transitions(s)
check_invalid_task_states(s)
check_worker_fail_hard(s)


def gen_cluster(
119 changes: 94 additions & 25 deletions distributed/worker.py
@@ -157,6 +157,72 @@
DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {}


def fail_hard(method):
"""
Decorator to close the worker if this method encounters an exception.
"""
if iscoroutinefunction(method):

@functools.wraps(method)
async def wrapper(self, *args, **kwargs):
try:
return await method(self, *args, **kwargs)
except Exception as e:
Review comment (collaborator):

Suggested change:
-    except Exception as e:
+    except BaseException as e:

This would fix #5958 if you also wrapped execute (you should wrap execute either way, I think).

The fact that BaseExceptions in callbacks aren't propagated by Tornado is pretty crazy. If we're going to add manual support for propagating exceptions like this, I don't see why we'd let BaseExceptions be ignored.

Review comment (member):

I'm OK with catching BaseException for tasks, i.e. in apply_function et al., to fix #5958.

However, I would be worried about closing workers upon an asyncio.CancelledError. While I don't think we're using cancellation in many places right now, this would be very confusing behavior if that ever changes.

(A standalone illustration of the Exception vs BaseException distinction, not part of the PR, follows after this hunk.)

if self.status not in (Status.closed, Status.closing):
self.log_event(
"worker-fail-hard",
{
**error_message(e),
"worker": self.address,
},
)
logger.exception(e)
await _force_close(self)

else:

@functools.wraps(method)
def wrapper(self, *args, **kwargs):
try:
return method(self, *args, **kwargs)
except Exception as e:
if self.status not in (Status.closed, Status.closing):
self.log_event(
"worker-fail-hard",
{
**error_message(e),
"worker": self.address,
},
)
logger.exception(e)
else:
self.loop.add_callback(_force_close, self)

return wrapper


async def _force_close(self):
"""
Used with the fail_hard decorator defined above

1. Wait for a worker to close
2. If it doesn't, log and kill the process
"""
try:
await asyncio.wait_for(self.close(nanny=False, executor_wait=False), 30)
except (Exception, BaseException): # <-- include BaseException here or not??
# Worker is in a very broken state if closing fails. We need to shut down immediately,
# to ensure things don't get even worse and this worker potentially deadlocks the cluster.
logger.critical(
"Error trying close worker in response to broken internal state. "
"Forcibly exiting worker NOW",
exc_info=True,
)
# use `os._exit` instead of `sys.exit` because of uncertainty
# around propagating `SystemExit` from asyncio callbacks
os._exit(1)


class Worker(ServerNode):
"""Worker node in a Dask distributed cluster

@@ -900,14 +966,15 @@ def logs(self):
return self._deque_handler.deque

def log_event(self, topic, msg):
self.loop.add_callback(
self.batched_stream.send,
{
"op": "log-event",
"topic": topic,
"msg": msg,
},
)
full_msg = {
"op": "log-event",
"topic": topic,
"msg": msg,
}
if self.thread_id == threading.get_ident():
self.batched_stream.send(full_msg)
else:
self.loop.add_callback(self.batched_stream.send, full_msg)
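
Not part of the diff: a generic sketch of the dispatch pattern the new `log_event` uses, i.e. send directly when already on the event-loop thread and otherwise hand the call off to the loop. Plain asyncio stands in for Tornado's IOLoop here, and all names are hypothetical:

```python
import asyncio
import threading

class LoopBoundSender:
    def __init__(self, loop: asyncio.AbstractEventLoop):
        self.loop = loop
        self.loop_thread_id = threading.get_ident()  # assumes construction on the loop thread

    def send(self, msg):
        print(f"send({msg!r}) on thread {threading.get_ident()}")

    def log_event(self, msg):
        if threading.get_ident() == self.loop_thread_id:
            # Already on the event-loop thread: call directly.
            self.send(msg)
        else:
            # Called from a worker thread: defer to the loop thread.
            self.loop.call_soon_threadsafe(self.send, msg)

async def main():
    sender = LoopBoundSender(asyncio.get_running_loop())
    sender.log_event("from the loop thread")                    # sent synchronously
    await asyncio.to_thread(sender.log_event, "from a thread")  # scheduled onto the loop
    await asyncio.sleep(0)                                      # let the callback run

asyncio.run(main())
```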

@property
def executing_count(self) -> int:
@@ -1199,22 +1266,19 @@ async def heartbeat(self):
finally:
self.heartbeat_active = False

@fail_hard
Review comment (collaborator):

Once the decorator is here, what's the point of having this block anymore?

    finally:
        if self.reconnect and self.status in Status.ANY_RUNNING:
            logger.info("Connection to scheduler broken. Reconnecting...")
            self.loop.add_callback(self.heartbeat)
        else:
            logger.info(
                "Connection to scheduler broken. Closing without reporting. Status: %s",
                self.status,
            )
            await self.close(report=False)

The decorator is going to close the worker as soon as the finally block completes, so what's the point of trying to reconnect if we're going to close either way? These seem like opposing behaviors.

I think we want to either remove the try/except entirely from handle_scheduler (because all we want to do on error is close the worker, and @fail_hard will do that for us anyway), or not use @fail_hard here, if we do in fact still want to reconnect in the face of errors.

(A small sketch of how the finally block and the decorator interact, not part of the PR, follows below.)
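
Not part of the PR: a tiny sketch of the interaction described above, using a simplified stand-in for fail_hard. The finally block runs first (scheduling the reconnect), and the decorator then closes the worker anyway, because the exception still propagates out of the wrapped coroutine:

```python
import asyncio

def fail_hard_sketch(fn):  # simplified stand-in for the PR's fail_hard
    async def wrapper(*args, **kwargs):
        try:
            return await fn(*args, **kwargs)
        except Exception:
            print("fail_hard: closing worker")
            raise
    return wrapper

@fail_hard_sketch
async def handle_scheduler_sketch():
    try:
        raise OSError("connection to scheduler broken")
    finally:
        print("finally: scheduling reconnect")  # runs before fail_hard reacts

try:
    asyncio.run(handle_scheduler_sketch())
except OSError:
    pass
# Output order: "finally: scheduling reconnect", then "fail_hard: closing worker".
```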

async def handle_scheduler(self, comm):
try:
await self.handle_stream(comm, every_cycle=[self.ensure_communicating])
except Exception as e:
logger.exception(e)
raise
finally:
if self.reconnect and self.status in Status.ANY_RUNNING:
logger.info("Connection to scheduler broken. Reconnecting...")
self.loop.add_callback(self.heartbeat)
else:
logger.info(
"Connection to scheduler broken. Closing without reporting. Status: %s",
self.status,
)
await self.close(report=False)
await self.handle_stream(comm, every_cycle=[self.ensure_communicating])

if self.reconnect and self.status in Status.ANY_RUNNING:
logger.info("Connection to scheduler broken. Reconnecting...")
self.loop.add_callback(self.heartbeat)
else:
logger.info(
"Connection to scheduler broken. Closing without reporting. Status: %s",
self.status,
)
await self.close(report=False)

async def upload_file(self, comm, filename=None, data=None, load=True):
out_filename = os.path.join(self.local_directory, filename)
@@ -1654,7 +1718,9 @@ async def get_data(
assert response == "OK", response
except OSError:
logger.exception(
"failed during get data with %s -> %s", self.address, who, exc_info=True
"failed during get data with %s -> %s",
self.address,
who,
)
comm.abort()
raise
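
For context on the exc_info change above (not part of the diff): logging.Logger.exception already records the active exception, so also passing exc_info=True is redundant. A minimal illustration:

```python
import logging

logging.basicConfig()
logger = logging.getLogger("distributed.worker")

try:
    raise OSError("comm closed")
except OSError:
    # logger.exception(...) logs at ERROR level and appends the current traceback
    # automatically -- equivalent to logger.error(..., exc_info=True).
    logger.exception("failed during get data with %s -> %s", "tcp://a:1234", "tcp://b:5678")
```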
@@ -2682,6 +2748,7 @@ def handle_stimulus(self, stim: StateMachineEvent) -> None:
self.transitions(recs, stimulus_id=stim.stimulus_id)
self._handle_instructions(instructions)

@fail_hard
@log_errors
def _handle_stimulus_from_task(
self, task: asyncio.Task[StateMachineEvent | None]
@@ -2695,6 +2762,7 @@ def _handle_stimulus_from_task(
if stim:
self.handle_stimulus(stim)

@fail_hard
def _handle_instructions(self, instructions: Instructions) -> None:
# TODO this method is temporary.
# See final design: https://github.com/dask/distributed/issues/5894
@@ -3023,6 +3091,7 @@ def _update_metrics_received_data(
self.counters["transfer-count"].add(len(data))
self.incoming_count += 1

@fail_hard
@log_errors
async def gather_dep(
self,
@@ -3548,6 +3617,7 @@ def _ensure_computing(self) -> RecsInstrs:

return recs, []

@fail_hard
async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | None:
if self.status in {Status.closing, Status.closed, Status.closing_gracefully}:
return None
@@ -4107,7 +4177,6 @@ def validate_state(self):

except Exception as e:
logger.error("Validate state failed. Closing.", exc_info=e)
self.loop.add_callback(self.close)
logger.exception(e)
if LOG_PDB:
import pdb