From e0a7525e14e2ab6f555d58b2d3b53372e930b259 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Fri, 12 Apr 2024 11:18:10 +0200 Subject: [PATCH] ensure workers are not downscaled when participating in p2p (#8610) --- distributed/diagnostics/plugin.py | 24 +++++++++++ distributed/scheduler.py | 3 ++ distributed/shuffle/_scheduler_plugin.py | 8 ++++ distributed/shuffle/tests/test_shuffle.py | 49 +++++++++++++++++++++-- 4 files changed, 80 insertions(+), 4 deletions(-) diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 3a866facc36..cd935140f7a 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -23,6 +23,7 @@ # circular imports from distributed.scheduler import Scheduler from distributed.scheduler import TaskStateState as SchedulerTaskStateState + from distributed.scheduler import WorkerState from distributed.worker import Worker from distributed.worker_state_machine import TaskStateState as WorkerTaskStateState @@ -205,6 +206,29 @@ def add_client(self, scheduler: Scheduler, client: str) -> None: def remove_client(self, scheduler: Scheduler, client: str) -> None: """Run when a client disconnects""" + def valid_workers_downscaling( + self, scheduler: Scheduler, workers: list[WorkerState] + ) -> list[WorkerState]: + """Determine which workers can be removed from the cluster + + This method is called when the scheduler is about to downscale the cluster + by removing workers. The method should return a list of worker states that + can be removed from the cluster. + + Parameters + ---------- + scheduler : Scheduler + The scheduler that is about to downscale the cluster. + workers : list + The list of worker states that are candidates for removal. + + Returns + ------- + list + The list of worker states that can be removed from the cluster.
+ """ + return workers + def log_event(self, topic: str, msg: Any) -> None: """Run when an event is logged""" diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 0251e0c310d..090e17be03b 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -7153,6 +7153,9 @@ def workers_to_close( # running on, as it would cause them to restart from scratch # somewhere else. valid_workers = [ws for ws in self.workers.values() if not ws.long_running] + for plugin in list(self.plugins.values()): + valid_workers = plugin.valid_workers_downscaling(self, valid_workers) + groups = groupby(key, valid_workers) limit_bytes = {k: sum(ws.memory_limit for ws in v) for k, v in groups.items()} diff --git a/distributed/shuffle/_scheduler_plugin.py b/distributed/shuffle/_scheduler_plugin.py index 09d97fffc9a..ef646bcea0d 100644 --- a/distributed/shuffle/_scheduler_plugin.py +++ b/distributed/shuffle/_scheduler_plugin.py @@ -410,6 +410,14 @@ def transition( if not archived: del self._archived_by_stimulus[shuffle._archived_by] + def valid_workers_downscaling( + self, scheduler: Scheduler, workers: list[WorkerState] + ) -> list[WorkerState]: + all_participating_workers = set() + for shuffle in self.active_shuffles.values(): + all_participating_workers.update(shuffle.participating_workers) + return [w for w in workers if w.address not in all_participating_workers] + def _fail_on_workers(self, shuffle: SchedulerShuffleState, message: str) -> None: worker_msgs = { worker: [ diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index c4e70e63d2f..be80107b50d 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -18,6 +18,7 @@ from packaging.version import parse as parse_version from tornado.ioloop import IOLoop +import dask from dask.utils import key_split from distributed.shuffle._core import ShuffleId, ShuffleRun, barrier_key @@ -28,7 +29,6 @@ import numpy as np import 
pandas as pd -import dask from dask.dataframe._compat import PANDAS_GE_150, PANDAS_GE_200 from dask.typing import Key @@ -139,9 +139,10 @@ async def test_minimal_version(c, s, a, b): dtypes={"x": float, "y": float}, freq="10 s", ) - with pytest.raises( - ModuleNotFoundError, match="requires pyarrow" - ), dask.config.set({"dataframe.shuffle.method": "p2p"}): + with ( + pytest.raises(ModuleNotFoundError, match="requires pyarrow"), + dask.config.set({"dataframe.shuffle.method": "p2p"}), + ): await c.compute(df.shuffle("x")) @@ -2795,3 +2796,43 @@ def data_gen(): "meta", ): await c.gather(c.compute(ddf.shuffle(on="a"))) + + +@gen_cluster(client=True) +async def test_dont_downscale_participating_workers(c, s, a, b): + df = dask.datasets.timeseries( + start="2000-01-01", + end="2000-01-10", + dtypes={"x": float, "y": float}, + freq="10 s", + ) + with dask.config.set({"dataframe.shuffle.method": "p2p"}): + shuffled = df.shuffle("x") + + workers_to_close = s.workers_to_close(n=2) + assert len(workers_to_close) == 2 + res = c.compute(shuffled) + + shuffle_id = await wait_until_new_shuffle_is_initialized(s) + while not get_active_shuffle_runs(a): + await asyncio.sleep(0.01) + while not get_active_shuffle_runs(b): + await asyncio.sleep(0.01) + + workers_to_close = s.workers_to_close(n=2) + assert len(workers_to_close) == 0 + + async with Worker(s.address) as w: + c.submit(lambda: None, workers=[w.address]) + + workers_to_close = s.workers_to_close(n=3) + assert len(workers_to_close) == 1 + + workers_to_close = s.workers_to_close(n=2) + assert len(workers_to_close) == 0 + + await c.gather(res) + del res + + workers_to_close = s.workers_to_close(n=2) + assert len(workers_to_close) == 2