Skip to content

Commit

Permalink
ensure workers are not downscaled when participating in p2p
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Apr 12, 2024
1 parent 66ced13 commit 9b302e7
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 4 deletions.
24 changes: 24 additions & 0 deletions distributed/diagnostics/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
# circular imports
from distributed.scheduler import Scheduler
from distributed.scheduler import TaskStateState as SchedulerTaskStateState
from distributed.scheduler import WorkerState
from distributed.worker import Worker
from distributed.worker_state_machine import TaskStateState as WorkerTaskStateState

Expand Down Expand Up @@ -205,6 +206,29 @@ def add_client(self, scheduler: Scheduler, client: str) -> None:
def remove_client(self, scheduler: Scheduler, client: str) -> None:
"""Run when a client disconnects

The body consists solely of this docstring, so the hook does nothing
as written here.

Parameters
----------
scheduler : Scheduler
    Scheduler instance passed to the hook.
client : str
    String identifying the client that disconnected.
"""

def valid_workers_downscaling(
    self, scheduler: Scheduler, workers: list[WorkerState]
) -> list[WorkerState]:
    """Determine which workers can be removed from the cluster

    This method is called when the scheduler is about to downscale the cluster
    by removing workers. The method should return the list of worker states
    that can be removed from the cluster.

    Parameters
    ----------
    scheduler : Scheduler
        The scheduler that is about to downscale the cluster.
    workers : list
        The list of worker states that are candidates for removal.

    Returns
    -------
    list
        The list of worker states that can be removed from the cluster.
        This default implementation vetoes nothing and hands ``workers``
        back unchanged.
    """
    return workers

def log_event(self, topic: str, msg: Any) -> None:
"""Run when an event is logged

The body consists solely of this docstring, so the hook does nothing
as written here.

Parameters
----------
topic : str
    Topic of the logged event.
msg : Any
    Payload of the logged event.
"""

Expand Down
3 changes: 3 additions & 0 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7153,6 +7153,9 @@ def workers_to_close(
# running on, as it would cause them to restart from scratch
# somewhere else.
valid_workers = [ws for ws in self.workers.values() if not ws.long_running]
for plugin in list(self.plugins.values()):
valid_workers = plugin.valid_workers_downscaling(self, valid_workers)

groups = groupby(key, valid_workers)

limit_bytes = {k: sum(ws.memory_limit for ws in v) for k, v in groups.items()}
Expand Down
8 changes: 8 additions & 0 deletions distributed/shuffle/_scheduler_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,14 @@ def transition(
if not archived:
del self._archived_by_stimulus[shuffle._archived_by]

def valid_workers_downscaling(
    self, scheduler: Scheduler, workers: list[WorkerState]
) -> list[WorkerState]:
    """Drop candidates that take part in any active shuffle.

    Parameters
    ----------
    scheduler : Scheduler
        Accepted as part of the hook signature; not used here.
    workers : list
        Worker states that are candidates for removal.

    Returns
    -------
    list
        The candidates, in their original order, whose ``address`` is not
        listed in the ``participating_workers`` of any entry in
        ``self.active_shuffles``.
    """
    # Collect the addresses once so each candidate needs a single lookup.
    busy_addresses = {
        address
        for state in self.active_shuffles.values()
        for address in state.participating_workers
    }
    return [ws for ws in workers if ws.address not in busy_addresses]

def _fail_on_workers(self, shuffle: SchedulerShuffleState, message: str) -> None:
worker_msgs = {
worker: [
Expand Down
49 changes: 45 additions & 4 deletions distributed/shuffle/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from packaging.version import parse as parse_version
from tornado.ioloop import IOLoop

import dask
from dask.utils import key_split

from distributed.shuffle._core import ShuffleId, ShuffleRun, barrier_key
Expand All @@ -28,7 +29,6 @@
import numpy as np
import pandas as pd

import dask
from dask.dataframe._compat import PANDAS_GE_150, PANDAS_GE_200
from dask.typing import Key

Expand Down Expand Up @@ -139,9 +139,10 @@ async def test_minimal_version(c, s, a, b):
dtypes={"x": float, "y": float},
freq="10 s",
)
with pytest.raises(
ModuleNotFoundError, match="requires pyarrow"
), dask.config.set({"dataframe.shuffle.method": "p2p"}):
with (
pytest.raises(ModuleNotFoundError, match="requires pyarrow"),
dask.config.set({"dataframe.shuffle.method": "p2p"}),
):
await c.compute(df.shuffle("x"))


Expand Down Expand Up @@ -2795,3 +2796,43 @@ def data_gen():
"meta",
):
await c.gather(c.compute(ddf.shuffle(on="a")))


@gen_cluster(client=True)
async def test_dont_downscale_participating_workers(c, s, a, b):
"""Check what ``s.workers_to_close`` offers before, during and after a p2p shuffle."""
# NOTE(review): indentation was lost in this view, so the extent of the
# ``with`` / ``while`` / ``async with`` bodies below cannot be read off the
# text; confirm against the original file before editing.
df = dask.datasets.timeseries(
start="2000-01-01",
end="2000-01-10",
dtypes={"x": float, "y": float},
freq="10 s",
)
# Build the shuffle with the p2p method selected via dask config.
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
shuffled = df.shuffle("x")

# Nothing has been submitted yet: both workers are expected to be closable.
workers_to_close = s.workers_to_close(n=2)
assert len(workers_to_close) == 2
res = c.compute(shuffled)

# Wait for the shuffle to show up on the scheduler, then poll until both
# workers report active shuffle runs.
# NOTE(review): shuffle_id is assigned but not read again in this test.
shuffle_id = await wait_until_new_shuffle_is_initialized(s)
while not get_active_shuffle_runs(a):
await asyncio.sleep(0.01)
while not get_active_shuffle_runs(b):
await asyncio.sleep(0.01)

# With runs active on both workers, no worker is expected to be offered.
workers_to_close = s.workers_to_close(n=2)
assert len(workers_to_close) == 0

# Start a third worker after the shuffle began and submit a trivial task
# restricted to it; the returned future is not kept.
async with Worker(s.address) as w:
c.submit(lambda: None, workers=[w.address])

# Asking for three workers is expected to yield exactly one, presumably
# the newly started worker, which is not part of the shuffle (TODO confirm).
workers_to_close = s.workers_to_close(n=3)
assert len(workers_to_close) == 1

# Asking for two workers is expected to yield none at this point.
workers_to_close = s.workers_to_close(n=2)
assert len(workers_to_close) == 0

# Gather the shuffle result and drop the local reference to it.
await c.gather(res)
del res

# Afterwards both workers are expected to be closable again.
workers_to_close = s.workers_to_close(n=2)
assert len(workers_to_close) == 2

0 comments on commit 9b302e7

Please sign in to comment.