From d89514163f6d2027adc6d0734e9e00035a06372f Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Tue, 26 Oct 2021 23:26:15 -0600 Subject: [PATCH 01/17] Consistent worker Client instance in `get_client` Fixes #4959 `get_client` was calling the private `Worker._get_client` method when it ran within a task. `_get_client` should really have been called `_make_client`, since it created a new client every time. The simplest correct thing to do instead would have been to use the `Worker.client` property, which caches this instance. In order to pass the `timeout` parameter through though, I changed `Worker.get_client` to actually match its docstring and always return the same instance. --- distributed/tests/test_worker.py | 23 +++++++ distributed/worker.py | 100 +++++++++++++++---------------- 2 files changed, 73 insertions(+), 50 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 7a3126e9d1..ca87b2d24a 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -44,6 +44,7 @@ TaskStateMetadataPlugin, _LockedCommPool, captured_logger, + cluster, dec, div, gen_cluster, @@ -965,6 +966,28 @@ def f(x): assert a._client is a_client +@gen_cluster(client=True, nthreads=[("127.0.0.1", 4)]) +async def test_get_client_threadsafe(c, s, a): + def f(x): + return get_client().id + + futures = c.map(f, range(100)) + ids = await c.gather(futures) + assert len(set(ids)) == 1 + + +def test_get_client_threadsafe_sync(): + def f(x): + return get_client().id + + with cluster(nworkers=1, worker_kwargs={"nthreads": 4}) as (scheduler, workers): + with Client(scheduler["address"]) as client: + futures = client.map(f, range(100)) + ids = client.gather(futures) + assert len(set(ids)) == 1 + assert set(ids) != {client.id} + + def test_get_client_sync(client): def f(x): cc = get_client() diff --git a/distributed/worker.py b/distributed/worker.py index ee17e115d4..caf0103e18 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -449,7 +449,7 @@ def __init__( self.has_what = defaultdict(set) self.pending_data_per_worker = defaultdict(deque) self.nanny = nanny - self._lock = threading.Lock() + self._client_lock = threading.Lock() self.data_needed = [] @@ -3554,11 +3554,7 @@ def validate_state(self): @property def client(self) -> Client: - with self._lock: - if self._client: - return self._client - else: - return self._get_client() + return self._get_client() def _get_client(self, timeout=None) -> Client: """Get local client attached to this worker @@ -3569,56 +3565,60 @@ def _get_client(self, timeout=None) -> Client: -------- get_client """ + with self._client_lock: + if self._client: + return self._client - if timeout is None: - timeout = dask.config.get("distributed.comm.timeouts.connect") - - timeout = parse_timedelta(timeout, "s") + if timeout is None: + timeout = dask.config.get("distributed.comm.timeouts.connect") - try: - from .client import default_client + timeout = parse_timedelta(timeout, "s") - client = default_client() - except ValueError: # no clients found, need to make a new one - pass - else: - # must be lazy import otherwise cyclic import - from distributed.deploy.cluster import Cluster + try: + from .client import default_client - if ( - client.scheduler - and client.scheduler.address == self.scheduler.address - # The below conditions should only happen in case a second - # cluster is alive, e.g. 
if a submitted task spawned its onwn - # LocalCluster, see gh4565 - or ( - isinstance(client._start_arg, str) - and client._start_arg == self.scheduler.address - or isinstance(client._start_arg, Cluster) - and client._start_arg.scheduler_address == self.scheduler.address + client = default_client() + except ValueError: # no clients found, need to make a new one + pass + else: + # must be lazy import otherwise cyclic import + from distributed.deploy.cluster import Cluster + + if ( + client.scheduler + and client.scheduler.address == self.scheduler.address + # The below conditions should only happen in case a second + # cluster is alive, e.g. if a submitted task spawned its onwn + # LocalCluster, see gh4565 + or ( + isinstance(client._start_arg, str) + and client._start_arg == self.scheduler.address + or isinstance(client._start_arg, Cluster) + and client._start_arg.scheduler_address + == self.scheduler.address + ) + ): + self._client = client + + if not self._client: + from .client import Client + + asynchronous = self.loop is IOLoop.current() + self._client = Client( + self.scheduler, + loop=self.loop, + security=self.security, + set_as_default=True, + asynchronous=asynchronous, + direct_to_workers=True, + name="worker", + timeout=timeout, ) - ): - self._client = client - - if not self._client: - from .client import Client - - asynchronous = self.loop is IOLoop.current() - self._client = Client( - self.scheduler, - loop=self.loop, - security=self.security, - set_as_default=True, - asynchronous=asynchronous, - direct_to_workers=True, - name="worker", - timeout=timeout, - ) - Worker._initialized_clients.add(self._client) - if not asynchronous: - assert self._client.status == "running" + Worker._initialized_clients.add(self._client) + if not asynchronous: + assert self._client.status == "running" - return self._client + return self._client def get_current_task(self): """Get the key of the task we are currently running From 5e8fcd4cb9f79c9f4d950fbd004a93785296fe72 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Tue, 26 Oct 2021 23:48:17 -0600 Subject: [PATCH 02/17] Check for Futures from the wrong Client in `gather` If you accidentally pass Futures created by a different Client into `Client.gather`, you'll get a `CancelledError`. This is confusing and misleading. An explicit check for this would have made discovering https://github.com/dask/distributed/issues/5466 much easier. And since there are probably plenty of other race conditions regarding default clients in multiple threads, hopefully a better error message will save someone else time in the future too. --- distributed/client.py | 8 ++++++++ distributed/tests/test_client.py | 11 +++++++++++ 2 files changed, 19 insertions(+) diff --git a/distributed/client.py b/distributed/client.py index cf370ec469..dc24f29bc5 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -1785,6 +1785,14 @@ def map( async def _gather(self, futures, errors="raise", direct=None, local_worker=None): unpacked, future_set = unpack_remotedata(futures, byte_keys=True) + mismatched_futures = [f for f in future_set if f.client is not self] + if mismatched_futures: + raise ValueError( + "Cannot gather Futures created by another client. 
" + f"These are the {len(mismatched_futures)} (out of {len(futures)}) mismatched Futures and their client IDs " + f"(this client is {self.id}): " + f"{ {f: f.client.id for f in mismatched_futures} }" + ) keys = [stringify(future.key) for future in future_set] bad_data = dict() data = {} diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 982212bf74..4a100aaa32 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -572,6 +572,17 @@ async def test_gather(c, s, a, b): assert result == {"x": 11, "y": [12]} +@gen_cluster(client=True) +async def test_gather_mismatched_client(c, s, a, b): + c2 = await Client(s.address, asynchronous=True) + + x = c.submit(inc, 10) + y = c2.submit(inc, 5) + + with pytest.raises(ValueError, match="Futures created by another client"): + await c.gather([x, y]) + + @gen_cluster(client=True) async def test_gather_lost(c, s, a, b): [x] = await c.scatter([1], workers=a.address) From 0cd61dfa0c8c215e148a1aecf4458bf037ab6943 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 27 Oct 2021 09:43:32 -0600 Subject: [PATCH 03/17] POC for scatter-based shuffle --- distributed/shuffle/__init__.py | 0 distributed/shuffle/scatter.py | 109 ++++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+) create mode 100644 distributed/shuffle/__init__.py create mode 100644 distributed/shuffle/scatter.py diff --git a/distributed/shuffle/__init__.py b/distributed/shuffle/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/distributed/shuffle/scatter.py b/distributed/shuffle/scatter.py new file mode 100644 index 0000000000..48c4b74f1e --- /dev/null +++ b/distributed/shuffle/scatter.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, TypeVar + +from dask.base import tokenize +from dask.dataframe import DataFrame +from dask.dataframe.core import _concat +from dask.dataframe.shuffle import shuffle_group +from dask.highlevelgraph import HighLevelGraph +from dask.sizeof import sizeof + +from distributed import Future, get_client + +if TYPE_CHECKING: + import pandas as pd + + +T = TypeVar("T") + + +class QuickSizeof(Generic[T]): + "Wrapper to bypass slow `sizeof` calls" + + def __init__(self, obj: T, size: int) -> None: + self.obj = obj + self.size = size + + def __sizeof__(self) -> int: + return self.size + + +def split( + df: pd.DataFrame, + column: str, + npartitions_output: int, + ignore_index: bool, + name: str, + row_size_estimate: int, + partition_info: dict[str, int] = None, +) -> dict[int, Future]: + assert isinstance(partition_info, dict), "partition_info is not a dict" + client = get_client() + + shards: dict[int, pd.DataFrame] = shuffle_group( + df, + cols=column, + stage=0, + k=npartitions_output, + npartitions=npartitions_output, + ignore_index=ignore_index, + nfinal=npartitions_output, + ) + input_partition_i = partition_info["number"] + # Change keys to be unique among all tasks---the dict keys here end up being + # the task keys on the scheduler. + # Also wrap in `QuickSizeof` to significantly speed up the worker storing each + # shard in its zict buffer. 
+ shards_rekeyed = { + f"({name!r}, {input_partition_i}, {output_partition_i})": QuickSizeof( + shard, len(shard) * row_size_estimate + ) + for output_partition_i, shard in shards.items() + } + # NOTE: `scatter` called within a task has very different (undocumented) behavior: + # it writes the keys directly to the current worker, then informs the scheduler + # that these keys exist on the current worker. No communications to other workers ever. + futures: dict[str, Future] = client.scatter(shards_rekeyed) + return dict(zip(shards, futures.values())) + + +def gather_regroup(i: int, all_futures: list[dict[int, Future]]) -> pd.DataFrame: + client = get_client() + futures = [fs[i] for fs in all_futures if i in fs] + shards: list[QuickSizeof[pd.DataFrame]] = client.gather(futures, direct=True) + # Since every worker holds a reference to all futures until the very last task completes, + # forcibly cancel these futures now to allow memory to be released eagerly. + # This is safe because we're only cancelling futures for this output partition, + # and there's exactly one task for each output partition. + client.cancel(futures, force=True) + + return _concat([s.obj for s in shards]) + + +def rearrange_by_column_scatter( + df: DataFrame, column: str, npartitions=None, ignore_index=False +) -> DataFrame: + token = tokenize(df, column) + + npartitions = npartitions or df.npartitions + row_size_estimate = sizeof(df._meta_nonempty) // len(df._meta_nonempty) + splits = df.map_partitions( + split, + column, + npartitions, + ignore_index, + f"shuffle-split-{token}", + row_size_estimate, + meta=df, + ) + + all_futures = splits.__dask_keys__() + name = f"shuffle-regroup-{token}" + dsk = {(name, i): (gather_regroup, i, all_futures) for i in range(npartitions)} + return DataFrame( + HighLevelGraph.from_collections(name, dsk, [splits]), + name, + df._meta, + [None] * (npartitions + 1), + ) From 30af07e2bf705001a0c7c260edf45077e59e6122 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 27 Oct 2021 09:45:40 -0600 Subject: [PATCH 04/17] Set current client in worker while deserializing dependencies This is probably a good idea in general (xref https://github.com/dask/distributed/issues/4959), but it particularly helps with performance deserializing Futures, which have a fastpath through `Client.current` that bypasses a number of unnecessarily slow things that `get_client` does before it checks `Client.current`. --- distributed/worker.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/distributed/worker.py b/distributed/worker.py index caf0103e18..7a4347788d 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2533,11 +2533,12 @@ async def gather_dep( worker, ) - start = time() - response = await get_data_from_worker( - self.rpc, to_gather_keys, worker, who=self.address - ) - stop = time() + with self.client.as_current(): + start = time() + response = await get_data_from_worker( + self.rpc, to_gather_keys, worker, who=self.address + ) + stop = time() if response["status"] == "busy": return From 3399830773f1630d96ff4182e76043af5fa145ef Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 27 Oct 2021 11:16:04 -0600 Subject: [PATCH 05/17] Don't report back locally-scattered keys When a key is scattered from a task and written directly to worker storage, the Client immediately sets the Future's state to `"finished"`. There's no need for the scheduler to also tell the client that that key is finished; it already knows. 
This saves a bit of scheduler time and a comms roundtrip. --- distributed/client.py | 1 + distributed/scheduler.py | 9 ++++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index dc24f29bc5..a47befa266 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2044,6 +2044,7 @@ async def _scatter( who_has={key: [local_worker.address] for key in data}, nbytes=valmap(sizeof, data), client=self.id, + report=False, ) else: diff --git a/distributed/scheduler.py b/distributed/scheduler.py index baa7957334..2e7360ac71 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4991,7 +4991,7 @@ def cancel_key(self, key, client, retries=5, force=False): for cs in clients: self.client_releases_keys(keys=[key], client=cs._client_key) - def client_desires_keys(self, keys=None, client=None): + def client_desires_keys(self, keys=None, client=None, report=True): parent: SchedulerState = cast(SchedulerState, self) cs: ClientState = parent._clients.get(client) if cs is None: @@ -5006,7 +5006,7 @@ def client_desires_keys(self, keys=None, client=None): ts._who_wants.add(cs) cs._wants_what.add(ts) - if ts._state in ("memory", "erred"): + if report and ts._state in ("memory", "erred"): self.report_on_key(ts=ts, client=client) def client_releases_keys(self, keys=None, client=None): @@ -6761,6 +6761,7 @@ def update_data( who_has: dict, nbytes: dict, client=None, + report=True, serializers=None, ): """ @@ -6795,7 +6796,9 @@ def update_data( ) if client: - self.client_desires_keys(keys=list(who_has), client=client) + self.client_desires_keys( + keys=list(who_has), client=client, report=report + ) def report_on_key(self, key: str = None, ts: TaskState = None, client: str = None): parent: SchedulerState = cast(SchedulerState, self) From 555be9b83141e0cc7408ad39ccd60e5bdf713de1 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 27 Oct 2021 11:24:16 -0600 Subject: [PATCH 06/17] also don't report to workers (is this right?) --- distributed/scheduler.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 2e7360ac71..38402ca1fc 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -6791,9 +6791,10 @@ def update_data( ws: WorkerState = parent._workers_dv[w] if ws not in ts._who_has: parent.add_replica(ts, ws) - self.report( - {"op": "key-in-memory", "key": key, "workers": list(workers)} - ) + if report: + self.report( + {"op": "key-in-memory", "key": key, "workers": list(workers)} + ) if client: self.client_desires_keys( From 8b2d7b0686c63f8d5a73a44ca8397235527ec142 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 27 Oct 2021 11:27:56 -0600 Subject: [PATCH 07/17] scatter docstrings + comments --- distributed/shuffle/scatter.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/distributed/shuffle/scatter.py b/distributed/shuffle/scatter.py index 48c4b74f1e..88b8d012b2 100644 --- a/distributed/shuffle/scatter.py +++ b/distributed/shuffle/scatter.py @@ -38,6 +38,7 @@ def split( row_size_estimate: int, partition_info: dict[str, int] = None, ) -> dict[int, Future]: + "Split input partition into shards per output group; scatter shards and return Futures referencing them." assert isinstance(partition_info, dict), "partition_info is not a dict" client = get_client() @@ -65,10 +66,12 @@ def split( # it writes the keys directly to the current worker, then informs the scheduler # that these keys exist on the current worker. 
No communications to other workers ever. futures: dict[str, Future] = client.scatter(shards_rekeyed) + # Switch keys back to output partition numbers so they're easier to select return dict(zip(shards, futures.values())) def gather_regroup(i: int, all_futures: list[dict[int, Future]]) -> pd.DataFrame: + "Given Futures for all shards, select Futures for this output partition, gather them, and concat." client = get_client() futures = [fs[i] for fs in all_futures if i in fs] shards: list[QuickSizeof[pd.DataFrame]] = client.gather(futures, direct=True) From 189867359555e41c200aabe3ee3ed21b52f62dc4 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 27 Oct 2021 11:28:16 -0600 Subject: [PATCH 08/17] faster names to key_split --- distributed/shuffle/scatter.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/distributed/shuffle/scatter.py b/distributed/shuffle/scatter.py index 88b8d012b2..fdaf908404 100644 --- a/distributed/shuffle/scatter.py +++ b/distributed/shuffle/scatter.py @@ -57,7 +57,8 @@ def split( # Also wrap in `QuickSizeof` to significantly speed up the worker storing each # shard in its zict buffer. shards_rekeyed = { - f"({name!r}, {input_partition_i}, {output_partition_i})": QuickSizeof( + # NOTE: this name is optimized to be easy for `key_split` to process + f"{name}-{input_partition_i}-{output_partition_i}": QuickSizeof( shard, len(shard) * row_size_estimate ) for output_partition_i, shard in shards.items() From 90e1c478b732868877a6acb489f9ef5be06dde58 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 27 Oct 2021 12:20:02 -0600 Subject: [PATCH 09/17] Preserve contextvars during comm offload Helps with setting the current client in worker while deserializing. Implementation referenced from https://github.com/python/cpython/pull/9688 --- distributed/tests/test_utils.py | 17 +++++++++++++++++ distributed/utils.py | 7 ++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index 0a698a3b3a..849e793f77 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -1,5 +1,6 @@ import array import asyncio +import contextvars import functools import io import os @@ -554,6 +555,22 @@ async def test_offload(): assert (await offload(lambda x, y: x + y, 1, y=2)) == 3 +@pytest.mark.asyncio +async def test_offload_preserves_contextvars(): + var = contextvars.ContextVar("var", default="foo") + + def change_var(): + var.set("bar") + return var.get() + + o1 = offload(var.get) + o2 = offload(change_var) + + r1, r2 = await asyncio.gather(o1, o2) + assert (r1, r2) == ("foo", "bar") + assert var.get() == "foo" + + def test_serialize_for_cli_deprecated(): with pytest.warns(FutureWarning, match="serialize_for_cli is deprecated"): from distributed.utils import serialize_for_cli diff --git a/distributed/utils.py b/distributed/utils.py index 4e79e3b36d..952de3482e 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextvars import functools import importlib import inspect @@ -1322,7 +1323,11 @@ def import_term(name: str): async def offload(fn, *args, **kwargs): loop = asyncio.get_event_loop() - return await loop.run_in_executor(_offload_executor, lambda: fn(*args, **kwargs)) + # Retain context vars while deserializing; see https://bugs.python.org/issue34014 + context = contextvars.copy_context() + return await loop.run_in_executor( + _offload_executor, lambda: context.run(fn, *args, **kwargs) + 
) class EmptyContext: From 6d9d09034d8b3136c02818904535b297ee290d57 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 27 Oct 2021 12:51:26 -0600 Subject: [PATCH 10/17] Remove maybe-superfluous message in setstate --- distributed/client.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index a47befa266..a9d81d1b83 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -379,14 +379,15 @@ def __setstate__(self, state): except ValueError: c = get_client(address) self.__init__(key, c) - c._send_to_scheduler( - { - "op": "update-graph", - "tasks": {}, - "keys": [stringify(self.key)], - "client": c.id, - } - ) + # TODO why was this here? Is it safe to remove? + # c._send_to_scheduler( + # { + # "op": "update-graph", + # "tasks": {}, + # "keys": [stringify(self.key)], + # "client": c.id, + # } + # ) def __del__(self): try: From 673ea24a574891897097d412dceda7e446ff8617 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 27 Oct 2021 12:52:21 -0600 Subject: [PATCH 11/17] FIXME remove address coercion in update_data This was really slow and probably doesn't matter when the future is coming from a worker. But probably not safe to remove in general? --- distributed/scheduler.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 38402ca1fc..af8f86b760 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -6773,9 +6773,7 @@ def update_data( """ parent: SchedulerState = cast(SchedulerState, self) with log_errors(): - who_has = { - k: [self.coerce_address(vv) for vv in v] for k, v in who_has.items() - } + # TODO add `coerce_address` back for some cases logger.debug("Update data %s", who_has) for key, workers in who_has.items(): From 43690286164062418bd8a29483cece2e145371ca Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 27 Oct 2021 13:18:06 -0600 Subject: [PATCH 12/17] no enforce_metadata or transform_divisions --- distributed/shuffle/scatter.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/distributed/shuffle/scatter.py b/distributed/shuffle/scatter.py index fdaf908404..eef19f6457 100644 --- a/distributed/shuffle/scatter.py +++ b/distributed/shuffle/scatter.py @@ -100,6 +100,8 @@ def rearrange_by_column_scatter( f"shuffle-split-{token}", row_size_estimate, meta=df, + enforce_metadata=False, + transform_divisions=False, ) all_futures = splits.__dask_keys__() From d485eaac68d99c50612ff5de489a83417ce80c51 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 27 Oct 2021 15:22:15 -0600 Subject: [PATCH 13/17] shuffle-split -> shuffle-shards --- distributed/shuffle/scatter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distributed/shuffle/scatter.py b/distributed/shuffle/scatter.py index eef19f6457..aed9f45f3b 100644 --- a/distributed/shuffle/scatter.py +++ b/distributed/shuffle/scatter.py @@ -97,7 +97,7 @@ def rearrange_by_column_scatter( column, npartitions, ignore_index, - f"shuffle-split-{token}", + f"shuffle-shards-{token}", row_size_estimate, meta=df, enforce_metadata=False, From 23b1e658f9d02c153af952b0772abb25b7dd572b Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 27 Oct 2021 16:39:00 -0600 Subject: [PATCH 14/17] REVERTME no-report cancellation We don't need to report back to the client that its key was cancelled. But this shouldn't be exposed and may be wrong. 
--- distributed/client.py | 16 ++++++++++++---- distributed/scheduler.py | 11 ++++++----- distributed/shuffle/scatter.py | 2 +- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index a9d81d1b83..8bd551967d 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2191,15 +2191,17 @@ def scatter( hash=hash, ) - async def _cancel(self, futures, force=False): + async def _cancel(self, futures, force=False, _report=True): keys = list({stringify(f.key) for f in futures_of(futures)}) - await self.scheduler.cancel(keys=keys, client=self.id, force=force) + await self.scheduler.cancel( + keys=keys, client=self.id, force=force, _report=_report + ) for k in keys: st = self.futures.pop(k, None) if st is not None: st.cancel() - def cancel(self, futures, asynchronous=None, force=False): + def cancel(self, futures, asynchronous=None, force=False, _report=True): """ Cancel running futures @@ -2213,7 +2215,13 @@ def cancel(self, futures, asynchronous=None, force=False): force : boolean (False) Cancel this future even if other clients desire it """ - return self.sync(self._cancel, futures, asynchronous=asynchronous, force=force) + return self.sync( + self._cancel, + futures, + asynchronous=asynchronous, + force=force, + _report=_report, + ) async def _retry(self, futures): keys = list({stringify(f.key) for f in futures_of(futures)}) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index af8f86b760..8fe4ca10a1 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4956,17 +4956,17 @@ def remove_worker_from_events(): return "OK" - def stimulus_cancel(self, comm, keys=None, client=None, force=False): + def stimulus_cancel(self, comm, keys=None, client=None, force=False, _report=True): """Stop execution on a list of keys""" logger.info("Client %s requests to cancel %d keys", client, len(keys)) - if client: + if client and _report: self.log_event( client, {"action": "cancel", "count": len(keys), "force": force} ) for key in keys: - self.cancel_key(key, client, force=force) + self.cancel_key(key, client, force=force, _report=_report) - def cancel_key(self, key, client, retries=5, force=False): + def cancel_key(self, key, client, retries=5, force=False, _report=True): """Cancel a particular key and all dependents""" # TODO: this should be converted to use the transition mechanism parent: SchedulerState = cast(SchedulerState, self) @@ -4986,7 +4986,8 @@ def cancel_key(self, key, client, retries=5, force=False): for dts in list(ts._dependents): self.cancel_key(dts._key, client, force=force) logger.info("Scheduler cancels key %s. Force=%s", key, force) - self.report({"op": "cancelled-key", "key": key}) + if _report: + self.report({"op": "cancelled-key", "key": key}) clients = list(ts._who_wants) if force else [cs] for cs in clients: self.client_releases_keys(keys=[key], client=cs._client_key) diff --git a/distributed/shuffle/scatter.py b/distributed/shuffle/scatter.py index aed9f45f3b..19de60a977 100644 --- a/distributed/shuffle/scatter.py +++ b/distributed/shuffle/scatter.py @@ -80,7 +80,7 @@ def gather_regroup(i: int, all_futures: list[dict[int, Future]]) -> pd.DataFrame # forcibly cancel these futures now to allow memory to be released eagerly. # This is safe because we're only cancelling futures for this output partition, # and there's exactly one task for each output partition. 
- client.cancel(futures, force=True) + client.cancel(futures, force=True, _report=False) return _concat([s.obj for s in shards]) From a46c507aca434859193d727772096ef060f92f8f Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 27 Oct 2021 16:39:42 -0600 Subject: [PATCH 15/17] REVERTME remove cancel logs --- distributed/scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 8fe4ca10a1..35ebd080dd 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4958,7 +4958,7 @@ def remove_worker_from_events(): def stimulus_cancel(self, comm, keys=None, client=None, force=False, _report=True): """Stop execution on a list of keys""" - logger.info("Client %s requests to cancel %d keys", client, len(keys)) + # logger.info("Client %s requests to cancel %d keys", client, len(keys)) if client and _report: self.log_event( client, {"action": "cancel", "count": len(keys), "force": force} @@ -4985,7 +4985,7 @@ def cancel_key(self, key, client, retries=5, force=False, _report=True): if force or ts._who_wants == {cs}: # no one else wants this key for dts in list(ts._dependents): self.cancel_key(dts._key, client, force=force) - logger.info("Scheduler cancels key %s. Force=%s", key, force) + # logger.info("Scheduler cancels key %s. Force=%s", key, force) if _report: self.report({"op": "cancelled-key", "key": key}) clients = list(ts._who_wants) if force else [cs] From 78c337f56140762aad74e66e5f7414245994466f Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 27 Oct 2021 17:02:44 -0600 Subject: [PATCH 16/17] REVERTME don't cancel Just want to see how it affects performance --- distributed/deploy/cluster.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py index 20a0a990f7..4bb174c859 100644 --- a/distributed/deploy/cluster.py +++ b/distributed/deploy/cluster.py @@ -94,9 +94,9 @@ async def _start(self): ) self._cluster_info.update(info) - self.periodic_callbacks["sync-cluster-info"] = PeriodicCallback( - self._sync_cluster_info, self._sync_interval * 1000 - ) + # self.periodic_callbacks["sync-cluster-info"] = PeriodicCallback( + # self._sync_cluster_info, self._sync_interval * 1000 + # ) for pc in self.periodic_callbacks.values(): pc.start() self.status = Status.running From aefe78a121efc001296244102a01355d42572eee Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 27 Oct 2021 17:14:51 -0600 Subject: [PATCH 17/17] HACK don't inform on deserialized keys Hoping this speeds up the transfer of Futures; makes no sense in general though. --- distributed/client.py | 2 +- distributed/shuffle/scatter.py | 14 +++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 8bd551967d..90ea6bc35d 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -378,7 +378,7 @@ def __setstate__(self, state): c = Client.current(allow_global=False) except ValueError: c = get_client(address) - self.__init__(key, c) + self.__init__(key, c, inform=False) # HACK inform!! # TODO why was this here? Is it safe to remove? 
# c._send_to_scheduler( # { diff --git a/distributed/shuffle/scatter.py b/distributed/shuffle/scatter.py index 19de60a977..19ea92c6d8 100644 --- a/distributed/shuffle/scatter.py +++ b/distributed/shuffle/scatter.py @@ -75,12 +75,16 @@ def gather_regroup(i: int, all_futures: list[dict[int, Future]]) -> pd.DataFrame "Given Futures for all shards, select Futures for this output partition, gather them, and concat." client = get_client() futures = [fs[i] for fs in all_futures if i in fs] + for f in futures: + # HACK: we disabled informing on deserialized futures, so manually mark them as finished + if not f.done(): + f._state.finish() shards: list[QuickSizeof[pd.DataFrame]] = client.gather(futures, direct=True) - # Since every worker holds a reference to all futures until the very last task completes, - # forcibly cancel these futures now to allow memory to be released eagerly. - # This is safe because we're only cancelling futures for this output partition, - # and there's exactly one task for each output partition. - client.cancel(futures, force=True, _report=False) + # # Since every worker holds a reference to all futures until the very last task completes, + # # forcibly cancel these futures now to allow memory to be released eagerly. + # # This is safe because we're only cancelling futures for this output partition, + # # and there's exactly one task for each output partition. + # client.cancel(futures, force=True, _report=False) return _concat([s.obj for s in shards])
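
Note: the series above is a proof of concept. For reference, a minimal usage sketch of the new scatter-based shuffle path, assuming the patches are applied; the cluster size, column name, and toy data here are illustrative assumptions, not part of the patches:

    # Sketch only: exercises distributed/shuffle/scatter.py as added in PATCH 03.
    import dask.dataframe as dd
    import pandas as pd

    from distributed import Client
    from distributed.shuffle.scatter import rearrange_by_column_scatter

    if __name__ == "__main__":
        # A distributed client is required: `split` and `gather_regroup`
        # call `get_client()` from inside tasks.
        client = Client(n_workers=4, threads_per_worker=2)

        df = dd.from_pandas(
            pd.DataFrame({"id": range(1_000), "x": range(1_000)}),
            npartitions=10,
        )

        # Repartition so rows with the same `id` land in the same output
        # partition, via the scatter/gather path instead of the existing
        # task-based or disk-based shuffle.
        shuffled = rearrange_by_column_scatter(df, "id", npartitions=10)
        print(shuffled.compute().head())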