Skip to content

Commit

Permalink
ensure workers are not downscaled when participating in p2p (#8610)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter authored Apr 12, 2024
1 parent 66ced13 commit e0a7525
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 4 deletions.
24 changes: 24 additions & 0 deletions distributed/diagnostics/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
# circular imports
from distributed.scheduler import Scheduler
from distributed.scheduler import TaskStateState as SchedulerTaskStateState
from distributed.scheduler import WorkerState
from distributed.worker import Worker
from distributed.worker_state_machine import TaskStateState as WorkerTaskStateState

Expand Down Expand Up @@ -205,6 +206,29 @@ def add_client(self, scheduler: Scheduler, client: str) -> None:
def remove_client(self, scheduler: Scheduler, client: str) -> None:
"""Run when a client disconnects"""

def valid_workers_downscaling(

This comment has been minimized.

Copy link
@hendrikmakait

hendrikmakait Apr 16, 2024

Member

nit: I think it would be more natural if plugins told us about the workers they care about and don't want to drop rather than telling us about all the workers they don't care about. (No real action here given that this is already committed.)

self, scheduler: Scheduler, workers: list[WorkerState]
) -> list[WorkerState]:
"""Determine which workers can be removed from the cluster
This method is called when the scheduler is about to downscale the cluster
by removing workers. The method should return a set of worker states that
can be removed from the cluster.
Parameters
----------
workers : list
The list of worker states that are candidates for removal.
stimulus_id : str
ID of stimulus causing the downscaling.
Returns
-------
list
The list of worker states that can be removed from the cluster.
"""
return workers

def log_event(self, topic: str, msg: Any) -> None:
"""Run when an event is logged"""

Expand Down
3 changes: 3 additions & 0 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7153,6 +7153,9 @@ def workers_to_close(
# running on, as it would cause them to restart from scratch
# somewhere else.
valid_workers = [ws for ws in self.workers.values() if not ws.long_running]
for plugin in list(self.plugins.values()):
valid_workers = plugin.valid_workers_downscaling(self, valid_workers)

groups = groupby(key, valid_workers)

limit_bytes = {k: sum(ws.memory_limit for ws in v) for k, v in groups.items()}
Expand Down
8 changes: 8 additions & 0 deletions distributed/shuffle/_scheduler_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,14 @@ def transition(
if not archived:
del self._archived_by_stimulus[shuffle._archived_by]

def valid_workers_downscaling(
self, scheduler: Scheduler, workers: list[WorkerState]
) -> list[WorkerState]:
all_participating_workers = set()
for shuffle in self.active_shuffles.values():
all_participating_workers.update(shuffle.participating_workers)
return [w for w in workers if w.address not in all_participating_workers]

def _fail_on_workers(self, shuffle: SchedulerShuffleState, message: str) -> None:
worker_msgs = {
worker: [
Expand Down
49 changes: 45 additions & 4 deletions distributed/shuffle/tests/test_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from packaging.version import parse as parse_version
from tornado.ioloop import IOLoop

import dask
from dask.utils import key_split

from distributed.shuffle._core import ShuffleId, ShuffleRun, barrier_key
Expand All @@ -28,7 +29,6 @@
import numpy as np
import pandas as pd

import dask
from dask.dataframe._compat import PANDAS_GE_150, PANDAS_GE_200
from dask.typing import Key

Expand Down Expand Up @@ -139,9 +139,10 @@ async def test_minimal_version(c, s, a, b):
dtypes={"x": float, "y": float},
freq="10 s",
)
with pytest.raises(
ModuleNotFoundError, match="requires pyarrow"
), dask.config.set({"dataframe.shuffle.method": "p2p"}):
with (
pytest.raises(ModuleNotFoundError, match="requires pyarrow"),
dask.config.set({"dataframe.shuffle.method": "p2p"}),
):
await c.compute(df.shuffle("x"))


Expand Down Expand Up @@ -2795,3 +2796,43 @@ def data_gen():
"meta",
):
await c.gather(c.compute(ddf.shuffle(on="a")))


@gen_cluster(client=True)
async def test_dont_downscale_participating_workers(c, s, a, b):
df = dask.datasets.timeseries(
start="2000-01-01",
end="2000-01-10",
dtypes={"x": float, "y": float},
freq="10 s",
)
with dask.config.set({"dataframe.shuffle.method": "p2p"}):
shuffled = df.shuffle("x")

workers_to_close = s.workers_to_close(n=2)
assert len(workers_to_close) == 2
res = c.compute(shuffled)

shuffle_id = await wait_until_new_shuffle_is_initialized(s)
while not get_active_shuffle_runs(a):
await asyncio.sleep(0.01)
while not get_active_shuffle_runs(b):
await asyncio.sleep(0.01)

workers_to_close = s.workers_to_close(n=2)
assert len(workers_to_close) == 0

async with Worker(s.address) as w:
c.submit(lambda: None, workers=[w.address])

workers_to_close = s.workers_to_close(n=3)
assert len(workers_to_close) == 1

workers_to_close = s.workers_to_close(n=2)
assert len(workers_to_close) == 0

await c.gather(res)
del res

workers_to_close = s.workers_to_close(n=2)
assert len(workers_to_close) == 2

0 comments on commit e0a7525

Please sign in to comment.