[DNM] Scatter shuffle proof-of-concept #5473

Closed; wants to merge 17 commits.
44 changes: 31 additions & 13 deletions distributed/client.py
@@ -378,15 +378,16 @@ def __setstate__(self, state):
c = Client.current(allow_global=False)
except ValueError:
c = get_client(address)
self.__init__(key, c)
c._send_to_scheduler(
{
"op": "update-graph",
"tasks": {},
"keys": [stringify(self.key)],
"client": c.id,
Review comment (Member):
I think this is about updating `who_wants` on the scheduler side; however, I don't know for sure.

}
)
self.__init__(key, c, inform=False) # HACK inform!!
# TODO why was this here? Is it safe to remove?
# c._send_to_scheduler(
# {
# "op": "update-graph",
# "tasks": {},
# "keys": [stringify(self.key)],
# "client": c.id,
# }
# )

def __del__(self):
try:
@@ -1785,6 +1786,14 @@ def map(

async def _gather(self, futures, errors="raise", direct=None, local_worker=None):
unpacked, future_set = unpack_remotedata(futures, byte_keys=True)
mismatched_futures = [f for f in future_set if f.client is not self]
if mismatched_futures:
raise ValueError(
"Cannot gather Futures created by another client. "
f"These are the {len(mismatched_futures)} (out of {len(futures)}) mismatched Futures and their client IDs "
f"(this client is {self.id}): "
f"{ {f: f.client.id for f in mismatched_futures} }"
)
keys = [stringify(future.key) for future in future_set]
bad_data = dict()
data = {}
@@ -2036,6 +2045,7 @@ async def _scatter(
who_has={key: [local_worker.address] for key in data},
nbytes=valmap(sizeof, data),
client=self.id,
report=False,
)

else:
@@ -2181,15 +2191,17 @@ def scatter(
hash=hash,
)

async def _cancel(self, futures, force=False):
async def _cancel(self, futures, force=False, _report=True):
keys = list({stringify(f.key) for f in futures_of(futures)})
await self.scheduler.cancel(keys=keys, client=self.id, force=force)
await self.scheduler.cancel(
keys=keys, client=self.id, force=force, _report=_report
)
for k in keys:
st = self.futures.pop(k, None)
if st is not None:
st.cancel()

def cancel(self, futures, asynchronous=None, force=False):
def cancel(self, futures, asynchronous=None, force=False, _report=True):
"""
Cancel running futures

@@ -2203,7 +2215,13 @@ def cancel(self, futures, asynchronous=None, force=False):
force : boolean (False)
Cancel this future even if other clients desire it
"""
return self.sync(self._cancel, futures, asynchronous=asynchronous, force=force)
return self.sync(
self._cancel,
futures,
asynchronous=asynchronous,
force=force,
_report=_report,
)

async def _retry(self, futures):
keys = list({stringify(f.key) for f in futures_of(futures)})
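On the reviewer's question above about `who_wants`: a plausible reading (an assumption, not confirmed in this PR) is that the removed "update-graph" message only existed to register the deserializing client's interest in the key. When a Future is created with `inform=True`, the client already sends a message along these lines, which the scheduler routes to `client_desires_keys` (patched later in this diff):

# Sketch of the normal inform=True path for a deserialized Future. The exact
# message shape is an assumption based on the existing "client-desires-keys"
# stream handler in distributed, not code from this PR.
c._send_to_scheduler(
    {
        "op": "client-desires-keys",
        "keys": [stringify(key)],
        "client": c.id,
    }
)

With `inform=False` (the HACK above), deserialized Futures skip this registration entirely, which is why `gather_regroup` in `distributed/shuffle/scatter.py` below has to mark them finished by hand.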
6 changes: 3 additions & 3 deletions distributed/deploy/cluster.py
@@ -94,9 +94,9 @@ async def _start(self):
)
self._cluster_info.update(info)

self.periodic_callbacks["sync-cluster-info"] = PeriodicCallback(
self._sync_cluster_info, self._sync_interval * 1000
)
# self.periodic_callbacks["sync-cluster-info"] = PeriodicCallback(
# self._sync_cluster_info, self._sync_interval * 1000
# )
for pc in self.periodic_callbacks.values():
pc.start()
self.status = Status.running
35 changes: 19 additions & 16 deletions distributed/scheduler.py
@@ -4956,17 +4956,17 @@ def remove_worker_from_events():

return "OK"

def stimulus_cancel(self, comm, keys=None, client=None, force=False):
def stimulus_cancel(self, comm, keys=None, client=None, force=False, _report=True):
"""Stop execution on a list of keys"""
logger.info("Client %s requests to cancel %d keys", client, len(keys))
if client:
# logger.info("Client %s requests to cancel %d keys", client, len(keys))
if client and _report:
self.log_event(
client, {"action": "cancel", "count": len(keys), "force": force}
)
for key in keys:
self.cancel_key(key, client, force=force)
self.cancel_key(key, client, force=force, _report=_report)

def cancel_key(self, key, client, retries=5, force=False):
def cancel_key(self, key, client, retries=5, force=False, _report=True):
"""Cancel a particular key and all dependents"""
# TODO: this should be converted to use the transition mechanism
parent: SchedulerState = cast(SchedulerState, self)
@@ -4985,13 +4985,14 @@ def cancel_key(self, key, client, retries=5, force=False):
if force or ts._who_wants == {cs}: # no one else wants this key
for dts in list(ts._dependents):
self.cancel_key(dts._key, client, force=force)
logger.info("Scheduler cancels key %s. Force=%s", key, force)
self.report({"op": "cancelled-key", "key": key})
# logger.info("Scheduler cancels key %s. Force=%s", key, force)
if _report:
self.report({"op": "cancelled-key", "key": key})
clients = list(ts._who_wants) if force else [cs]
for cs in clients:
self.client_releases_keys(keys=[key], client=cs._client_key)

def client_desires_keys(self, keys=None, client=None):
def client_desires_keys(self, keys=None, client=None, report=True):
parent: SchedulerState = cast(SchedulerState, self)
cs: ClientState = parent._clients.get(client)
if cs is None:
@@ -5006,7 +5007,7 @@ def client_desires_keys(self, keys=None, client=None):
ts._who_wants.add(cs)
cs._wants_what.add(ts)

if ts._state in ("memory", "erred"):
if report and ts._state in ("memory", "erred"):
self.report_on_key(ts=ts, client=client)

def client_releases_keys(self, keys=None, client=None):
@@ -6761,6 +6762,7 @@ def update_data(
who_has: dict,
nbytes: dict,
client=None,
report=True,
serializers=None,
):
"""
@@ -6772,9 +6774,7 @@
"""
parent: SchedulerState = cast(SchedulerState, self)
with log_errors():
who_has = {
k: [self.coerce_address(vv) for vv in v] for k, v in who_has.items()
}
# TODO add `coerce_address` back for some cases
logger.debug("Update data %s", who_has)

for key, workers in who_has.items():
@@ -6790,12 +6790,15 @@
ws: WorkerState = parent._workers_dv[w]
if ws not in ts._who_has:
parent.add_replica(ts, ws)
self.report(
{"op": "key-in-memory", "key": key, "workers": list(workers)}
)
if report:
self.report(
{"op": "key-in-memory", "key": key, "workers": list(workers)}
)

if client:
self.client_desires_keys(keys=list(who_has), client=client)
self.client_desires_keys(
keys=list(who_has), client=client, report=report
)

def report_on_key(self, key: str = None, ts: TaskState = None, client: str = None):
parent: SchedulerState = cast(SchedulerState, self)
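To make the intent of the new `report=`/`_report=` flags concrete, here is a minimal sketch of the kind of call they are meant to quiet. The names `scheduler`, `worker_address`, and `client_id`, as well as the key and byte count, are illustrative; the call assumes the `update_data` signature shown in this diff:

# Illustrative only: register a scattered shard with the scheduler without
# broadcasting "key-in-memory" to every client that wants the key.
scheduler.update_data(
    who_has={"shuffle-shards-abc-0-3": [worker_address]},
    nbytes={"shuffle-shards-abc-0-3": 1024},
    client=client_id,
    report=False,  # skips self.report(...) and the report inside client_desires_keys
)

Suppressing these broadcasts presumably matters because every shard becomes its own scheduler key, so reporting each one to all interested clients would generate a large amount of comm traffic during a shuffle.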
Empty file.
119 changes: 119 additions & 0 deletions distributed/shuffle/scatter.py
@@ -0,0 +1,119 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Generic, TypeVar

from dask.base import tokenize
from dask.dataframe import DataFrame
from dask.dataframe.core import _concat
from dask.dataframe.shuffle import shuffle_group
from dask.highlevelgraph import HighLevelGraph
from dask.sizeof import sizeof

from distributed import Future, get_client

if TYPE_CHECKING:
import pandas as pd


T = TypeVar("T")


class QuickSizeof(Generic[T]):
"Wrapper to bypass slow `sizeof` calls"

def __init__(self, obj: T, size: int) -> None:
self.obj = obj
self.size = size

def __sizeof__(self) -> int:
return self.size


def split(
df: pd.DataFrame,
column: str,
npartitions_output: int,
ignore_index: bool,
name: str,
row_size_estimate: int,
partition_info: dict[str, int] = None,
) -> dict[int, Future]:
"Split input partition into shards per output group; scatter shards and return Futures referencing them."
assert isinstance(partition_info, dict), "partition_info is not a dict"
client = get_client()

shards: dict[int, pd.DataFrame] = shuffle_group(
df,
cols=column,
stage=0,
k=npartitions_output,
npartitions=npartitions_output,
ignore_index=ignore_index,
nfinal=npartitions_output,
)
input_partition_i = partition_info["number"]
# Change keys to be unique among all tasks---the dict keys here end up being
# the task keys on the scheduler.
# Also wrap in `QuickSizeof` to significantly speed up the worker storing each
# shard in its zict buffer.
shards_rekeyed = {
# NOTE: this name is optimized to be easy for `key_split` to process
f"{name}-{input_partition_i}-{output_partition_i}": QuickSizeof(
shard, len(shard) * row_size_estimate
)
for output_partition_i, shard in shards.items()
}
# NOTE: `scatter` called within a task has very different (undocumented) behavior:
# it writes the keys directly to the current worker, then informs the scheduler
# that these keys exist on the current worker. No communications to other workers ever.
futures: dict[str, Future] = client.scatter(shards_rekeyed)
# Switch keys back to output partition numbers so they're easier to select
return dict(zip(shards, futures.values()))


def gather_regroup(i: int, all_futures: list[dict[int, Future]]) -> pd.DataFrame:
"Given Futures for all shards, select Futures for this output partition, gather them, and concat."
client = get_client()
futures = [fs[i] for fs in all_futures if i in fs]
for f in futures:
# HACK: we disabled informing on deserialized futures, so manually mark them as finished
if not f.done():
f._state.finish()
shards: list[QuickSizeof[pd.DataFrame]] = client.gather(futures, direct=True)
# # Since every worker holds a reference to all futures until the very last task completes,
# # forcibly cancel these futures now to allow memory to be released eagerly.
# # This is safe because we're only cancelling futures for this output partition,
# # and there's exactly one task for each output partition.
# client.cancel(futures, force=True, _report=False)

return _concat([s.obj for s in shards])


def rearrange_by_column_scatter(
df: DataFrame, column: str, npartitions=None, ignore_index=False
) -> DataFrame:
token = tokenize(df, column)

npartitions = npartitions or df.npartitions
row_size_estimate = sizeof(df._meta_nonempty) // len(df._meta_nonempty)
splits = df.map_partitions(
split,
column,
npartitions,
ignore_index,
f"shuffle-shards-{token}",
row_size_estimate,
meta=df,
enforce_metadata=False,
transform_divisions=False,
)

all_futures = splits.__dask_keys__()
name = f"shuffle-regroup-{token}"
dsk = {(name, i): (gather_regroup, i, all_futures) for i in range(npartitions)}
return DataFrame(
HighLevelGraph.from_collections(name, dsk, [splits]),
name,
df._meta,
[None] * (npartitions + 1),
)
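For orientation, a minimal usage sketch of the scatter-based shuffle added in this file. The DataFrame contents, column name, and partition counts are made up, and a running distributed cluster is assumed, since `split` and `gather_regroup` call `get_client()` from inside tasks:

import dask.dataframe as dd
import pandas as pd

from distributed import Client
from distributed.shuffle.scatter import rearrange_by_column_scatter

client = Client()  # workers must be able to call get_client() inside tasks

pdf = pd.DataFrame({"id": range(1_000), "value": range(1_000)})
ddf = dd.from_pandas(pdf, npartitions=10)

# Hash-partition on "id": each input partition scatters its shards from inside
# its `split` task, and each output partition's `gather_regroup` task gathers
# only the shards it needs, directly from the workers holding them.
shuffled = rearrange_by_column_scatter(ddf, "id", npartitions=10)
result = shuffled.compute()
assert len(result) == len(pdf)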
11 changes: 11 additions & 0 deletions distributed/tests/test_client.py
@@ -572,6 +572,17 @@ async def test_gather(c, s, a, b):
assert result == {"x": 11, "y": [12]}


@gen_cluster(client=True)
async def test_gather_mismatched_client(c, s, a, b):
c2 = await Client(s.address, asynchronous=True)

x = c.submit(inc, 10)
y = c2.submit(inc, 5)

with pytest.raises(ValueError, match="Futures created by another client"):
await c.gather([x, y])


@gen_cluster(client=True)
async def test_gather_lost(c, s, a, b):
[x] = await c.scatter([1], workers=a.address)
17 changes: 17 additions & 0 deletions distributed/tests/test_utils.py
@@ -1,5 +1,6 @@
import array
import asyncio
import contextvars
import functools
import io
import os
@@ -554,6 +555,22 @@ async def test_offload():
assert (await offload(lambda x, y: x + y, 1, y=2)) == 3


@pytest.mark.asyncio
async def test_offload_preserves_contextvars():
var = contextvars.ContextVar("var", default="foo")

def change_var():
var.set("bar")
return var.get()

o1 = offload(var.get)
o2 = offload(change_var)

r1, r2 = await asyncio.gather(o1, o2)
assert (r1, r2) == ("foo", "bar")
assert var.get() == "foo"


def test_serialize_for_cli_deprecated():
with pytest.warns(FutureWarning, match="serialize_for_cli is deprecated"):
from distributed.utils import serialize_for_cli
23 changes: 23 additions & 0 deletions distributed/tests/test_worker.py
@@ -44,6 +44,7 @@
TaskStateMetadataPlugin,
_LockedCommPool,
captured_logger,
cluster,
dec,
div,
gen_cluster,
@@ -965,6 +966,28 @@ def f(x):
assert a._client is a_client


@gen_cluster(client=True, nthreads=[("127.0.0.1", 4)])
async def test_get_client_threadsafe(c, s, a):
def f(x):
return get_client().id

futures = c.map(f, range(100))
ids = await c.gather(futures)
assert len(set(ids)) == 1


def test_get_client_threadsafe_sync():
def f(x):
return get_client().id

with cluster(nworkers=1, worker_kwargs={"nthreads": 4}) as (scheduler, workers):
with Client(scheduler["address"]) as client:
futures = client.map(f, range(100))
ids = client.gather(futures)
assert len(set(ids)) == 1
assert set(ids) != {client.id}


def test_get_client_sync(client):
def f(x):
cc = get_client()