Fix P2P shuffle with LocalCluster(..., processes=False) (dask#8125)
hendrikmakait authored Aug 23, 2023
1 parent 6584e5c commit 03ea2e1
Showing 3 changed files with 36 additions and 12 deletions.
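
With processes=False, LocalCluster runs the scheduler and workers in a single process and communicates over the in-process transport, so ToPickle-wrapped messages may never actually pass through pickle; the FIXME comments in the diffs below note exactly that, and the fix unwraps the wrapper defensively on both the scheduler and the worker side. Below is a minimal sketch of the scenario the fix covers, adapted from the test added in this commit (names and parameters mirror that test; the script itself is not part of the commit):

import asyncio

import dask
import dask.dataframe as dd
from distributed import LocalCluster


async def main() -> None:
    # An in-process cluster: scheduler and workers share a single process.
    async with LocalCluster(
        n_workers=2, processes=False, asynchronous=True, dashboard_address=":0"
    ) as cluster:
        client = cluster.get_client()
        df = dask.datasets.timeseries(
            start="2000-01-01",
            end="2000-01-10",
            dtypes={"x": float, "y": float},
            freq="10 s",
        )
        # Before this fix, forcing the P2P shuffle failed on processes=False clusters.
        shuffled = dd.shuffle.shuffle(df, "x", shuffle="p2p")
        expected, actual = await client.gather(client.compute([df, shuffled]))
        dd.assert_eq(expected, actual)


asyncio.run(main())
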
6 changes: 5 additions & 1 deletion distributed/shuffle/_scheduler_plugin.py
@@ -115,10 +115,14 @@ def get(self, id: ShuffleId, worker: str) -> ToPickle[ShuffleRunSpec]:
 
     def get_or_create(
         self,
-        spec: ShuffleSpec,
+        # FIXME: This should never be ToPickle[ShuffleSpec]
+        spec: ShuffleSpec | ToPickle[ShuffleSpec],
         key: str,
         worker: str,
     ) -> ToPickle[ShuffleRunSpec]:
+        # FIXME: Sometimes, this doesn't actually get pickled
+        if isinstance(spec, ToPickle):
+            spec = spec.data
         try:
             return self.get(spec.id, worker)
         except KeyError:
18 changes: 8 additions & 10 deletions distributed/shuffle/_worker_plugin.py
@@ -218,7 +218,7 @@ async def _get_or_create_shuffle(
         if shuffle is None:
             shuffle = await self._refresh_shuffle(
                 shuffle_id=spec.id,
-                spec=ToPickle(spec),
+                spec=spec,
                 key=key,
             )
 
@@ -239,33 +239,32 @@ async def _refresh_shuffle
     async def _refresh_shuffle(
         self,
         shuffle_id: ShuffleId,
-        spec: ToPickle,
+        spec: ShuffleSpec,
         key: str,
     ) -> ShuffleRun:
         ...
 
     async def _refresh_shuffle(
         self,
         shuffle_id: ShuffleId,
-        spec: ToPickle | None = None,
+        spec: ShuffleSpec | None = None,
         key: str | None = None,
     ) -> ShuffleRun:
-        result: ShuffleRunSpec
+        # FIXME: This should never be ToPickle[ShuffleRunSpec]
+        result: ShuffleRunSpec | ToPickle[ShuffleRunSpec]
         if spec is None:
             result = await self.worker.scheduler.shuffle_get(
                 id=shuffle_id,
                 worker=self.worker.address,
             )
         else:
             result = await self.worker.scheduler.shuffle_get_or_create(
-                spec=spec,
+                spec=ToPickle(spec),
                 key=key,
                 worker=self.worker.address,
             )
-        # if result["status"] == "error":
-        #     raise RuntimeError(result["message"])
-        # assert result["status"] == "OK"
-
+        if isinstance(result, ToPickle):
+            result = result.data
         if self.closed:
             raise ShuffleClosedError(f"{self} has already been closed")
         if shuffle_id in self.shuffles:
@@ -287,7 +286,6 @@ async def _(
             extension._runs_cleanup_condition.notify_all()
 
         self.worker._ongoing_background_tasks.call_soon(_, self, existing)
-
         shuffle: ShuffleRun = result.spec.create_run_on_worker(
             result.run_id, result.worker_for, self
         )
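
Taken together, the two changes above have the worker wrap the spec in ToPickle at the RPC call site while both scheduler and worker unwrap defensively on receipt, since the in-process transport used by processes=False clusters can deliver the wrapper object itself rather than its pickled payload. A schematic, standalone illustration of that unwrap pattern follows (the ToPickle class here is a stand-in for the wrapper used in the diffs above, and the unwrap helper is hypothetical, not code from this commit):

class ToPickle:
    """Stand-in for distributed's ToPickle wrapper: holds its payload in .data."""

    def __init__(self, data):
        self.data = data


def unwrap(value):
    # Over TCP the wrapper is consumed during (de)serialization and only the
    # payload arrives; over the in-process transport the wrapper survives,
    # so the receiver has to unwrap it manually before use.
    return value.data if isinstance(value, ToPickle) else value


assert unwrap(ToPickle("spec")) == "spec"  # in-process path: wrapper survived
assert unwrap("spec") == "spec"            # TCP path: wrapper already consumed
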
24 changes: 23 additions & 1 deletion distributed/shuffle/tests/test_shuffle.py
@@ -23,7 +23,7 @@
 dd = pytest.importorskip("dask.dataframe")
 
 import dask
-from dask.distributed import Event, Nanny, Worker
+from dask.distributed import Event, LocalCluster, Nanny, Worker
 from dask.utils import stringify
 
 from distributed.client import Client
@@ -187,6 +187,28 @@ async def test_basic_integration(c, s, a, b, lose_annotations, npartitions):
     await check_scheduler_cleanup(s)
 
 
+@pytest.mark.parametrize("processes", [True, False])
+@gen_test()
+async def test_basic_integration_local_cluster(processes):
+    async with LocalCluster(
+        n_workers=2,
+        processes=processes,
+        asynchronous=True,
+        dashboard_address=":0",
+    ) as cluster:
+        df = dask.datasets.timeseries(
+            start="2000-01-01",
+            end="2000-01-10",
+            dtypes={"x": float, "y": float},
+            freq="10 s",
+        )
+        c = cluster.get_client()
+        out = dd.shuffle.shuffle(df, "x", shuffle="p2p")
+        x, y = c.compute([df, out])
+        x, y = await c.gather([x, y])
+        dd.assert_eq(x, y)
+
+
 @pytest.mark.parametrize("npartitions", [None, 1, 20])
 @gen_cluster(client=True)
 async def test_shuffle_with_array_conversion(c, s, a, b, lose_annotations, npartitions):
