From 9a8b380d2f4a6087c5d4cdd916fc8504e88ea227 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Fri, 6 Oct 2023 11:55:58 +0100 Subject: [PATCH] SpecCluster resilience to broken workers (#8233) --- distributed/deploy/spec.py | 26 ++++++++------ distributed/deploy/tests/test_local.py | 31 ++++++++++------ distributed/deploy/tests/test_spec_cluster.py | 12 ++----- distributed/tests/test_utils_test.py | 35 ++++++++++++++++--- distributed/utils_test.py | 20 ++++++----- 5 files changed, 80 insertions(+), 44 deletions(-) diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py index de48a231ad..79e1b396c4 100644 --- a/distributed/deploy/spec.py +++ b/distributed/deploy/spec.py @@ -379,13 +379,15 @@ async def _correct_state_internal(self) -> None: self._created.add(worker) workers.append(worker) if workers: - await asyncio.wait( - [asyncio.create_task(_wrap_awaitable(w)) for w in workers] - ) + worker_futs = [asyncio.ensure_future(w) for w in workers] + await asyncio.wait(worker_futs) + self.workers.update(dict(zip(to_open, workers))) for w in workers: w._cluster = weakref.ref(self) - await w # for tornado gen.coroutine support - self.workers.update(dict(zip(to_open, workers))) + # Collect exceptions from failed workers. This must happen after all + # *other* workers have finished initialising, so that we can have a + # proper teardown. + await asyncio.gather(*worker_futs) def _update_worker_status(self, op, msg): if op == "remove": @@ -467,10 +469,14 @@ async def _close(self): await super()._close() async def __aenter__(self): - await self - await self._correct_state() - assert self.status == Status.running - return self + try: + await self + await self._correct_state() + assert self.status == Status.running + return self + except Exception: + await self.close() + raise def _threads_per_worker(self) -> int: """Return the number of threads per worker for new workers""" @@ -678,8 +684,6 @@ async def run_spec(spec: dict[str, Any], *args: Any) -> dict[str, Worker | Nanny if workers: await asyncio.gather(*workers.values()) - for w in workers.values(): - await w # for tornado gen.coroutine support return workers diff --git a/distributed/deploy/tests/test_local.py b/distributed/deploy/tests/test_local.py index 853900c6c5..2ddb0f0bbf 100644 --- a/distributed/deploy/tests/test_local.py +++ b/distributed/deploy/tests/test_local.py @@ -1141,16 +1141,18 @@ async def test_local_cluster_redundant_kwarg(nanny): dashboard_address=":0", asynchronous=True, ) - try: - with pytest.raises(TypeError, match="unexpected keyword argument"): - # Extra arguments are forwarded to the worker class. Depending on - # whether we use the nanny or not, the error treatment is quite - # different and we should assert that an exception is raised - async with cluster: - pass - finally: - # FIXME: LocalCluster leaks if LocalCluster.__aenter__ raises - await cluster.close() + if nanny: + ctx = raises_with_cause( + RuntimeError, None, TypeError, "unexpected keyword argument" + ) + else: + ctx = pytest.raises(TypeError, match="unexpected keyword argument") + with ctx: + # Extra arguments are forwarded to the worker class. Depending on + # whether we use the nanny or not, the error treatment is quite + # different and we should assert that an exception is raised + async with cluster: + pass @gen_test() @@ -1255,7 +1257,14 @@ def setup(self, worker=None): @pytest.mark.slow def test_localcluster_start_exception(loop): - with raises_with_cause(RuntimeError, None, ImportError, "my_nonexistent_library"): + with raises_with_cause( + RuntimeError, + "Nanny failed to start", + RuntimeError, + "Worker failed to start", + ImportError, + "my_nonexistent_library", + ): with LocalCluster( n_workers=1, threads_per_worker=1, diff --git a/distributed/deploy/tests/test_spec_cluster.py b/distributed/deploy/tests/test_spec_cluster.py index f875db0c3e..b6d5a1cb03 100644 --- a/distributed/deploy/tests/test_spec_cluster.py +++ b/distributed/deploy/tests/test_spec_cluster.py @@ -207,7 +207,6 @@ async def test_restart(): await asyncio.sleep(0.01) -@pytest.mark.skipif(WINDOWS, reason="HTTP Server doesn't close out") @gen_test() async def test_broken_worker(): class BrokenWorkerException(Exception): @@ -216,7 +215,6 @@ class BrokenWorkerException(Exception): class BrokenWorker(Worker): def __await__(self): async def _(): - self.status = Status.closed raise BrokenWorkerException("Worker Broken") return _().__await__() @@ -226,13 +224,9 @@ async def _(): workers={"good": {"cls": Worker}, "bad": {"cls": BrokenWorker}}, scheduler=scheduler, ) - try: - with pytest.raises(BrokenWorkerException, match=r"Worker Broken"): - async with cluster: - pass - finally: - # FIXME: SpecCluster leaks if SpecCluster.__aenter__ raises - await cluster.close() + with pytest.raises(BrokenWorkerException, match=r"Worker Broken"): + async with cluster: + pass @pytest.mark.skipif(WINDOWS, reason="HTTP Server doesn't close out") diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 0cc86118ef..167c2aba42 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -764,16 +764,16 @@ def test_raises_with_cause(): raise RuntimeError("foo") from ValueError("bar") # we're trying to stick to pytest semantics - # If the exception types don't match, raise the original exception + # If the exception types don't match, raise the first exception that doesnt' match # If the text doesn't match, raise an assert - with pytest.raises(RuntimeError): + with pytest.raises(OSError): with raises_with_cause(RuntimeError, "exception", ValueError, "cause"): raise RuntimeError("exception") from OSError("cause") - with pytest.raises(ValueError): + with pytest.raises(OSError): with raises_with_cause(RuntimeError, "exception", ValueError, "cause"): - raise ValueError("exception") from ValueError("cause") + raise OSError("exception") from ValueError("cause") with pytest.raises(AssertionError): with raises_with_cause(RuntimeError, "exception", ValueError, "foo"): @@ -783,6 +783,33 @@ def test_raises_with_cause(): with raises_with_cause(RuntimeError, "foo", ValueError, "cause"): raise RuntimeError("exception") from ValueError("cause") + # There can be more than one nested cause + with raises_with_cause( + RuntimeError, "exception", ValueError, "cause1", OSError, "cause2" + ): + try: + raise ValueError("cause1") from OSError("cause2") + except ValueError as e: + raise RuntimeError("exception") from e + + with pytest.raises(OSError): + with raises_with_cause( + RuntimeError, "exception", ValueError, "cause1", TypeError, "cause2" + ): + try: + raise ValueError("cause1") from OSError("cause2") + except ValueError as e: + raise RuntimeError("exception") from e + + with pytest.raises(AssertionError): + with raises_with_cause( + RuntimeError, "exception", ValueError, "cause1", OSError, "cause2" + ): + try: + raise ValueError("cause1") from OSError("no match") + except ValueError as e: + raise RuntimeError("exception") from e + @pytest.mark.slow def test_check_thread_leak(): diff --git a/distributed/utils_test.py b/distributed/utils_test.py index d024340a7e..a67dad3170 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -12,7 +12,6 @@ import logging.config import multiprocessing import os -import re import signal import socket import ssl @@ -2100,8 +2099,10 @@ def raises_with_cause( match: str | None, expected_cause: type[BaseException] | tuple[type[BaseException], ...], match_cause: str | None, + *more_causes: type[BaseException] | tuple[type[BaseException], ...] | str | None, ) -> Generator[None, None, None]: - """Contextmanager to assert that a certain exception with cause was raised + """Contextmanager to assert that a certain exception with cause was raised. + It can travel the causes recursively by adding more expected, match pairs at the end. Parameters ---------- @@ -2111,13 +2112,14 @@ def raises_with_cause( yield exc = exc_info.value - assert exc.__cause__ - if not isinstance(exc.__cause__, expected_cause): - raise exc - if match_cause: - assert re.search( - match_cause, str(exc.__cause__) - ), f"Pattern ``{match_cause}`` not found in ``{exc.__cause__}``" + causes = [expected_cause, *more_causes[::2]] + match_causes = [match_cause, *more_causes[1::2]] + assert len(causes) == len(match_causes) + for expected_cause, match_cause in zip(causes, match_causes): # type: ignore + assert exc.__cause__ + exc = exc.__cause__ + with pytest.raises(expected_cause, match=match_cause): + raise exc def ucx_exception_handler(loop, context):