From 285893037fe9eac83f363611b4799168aabb3992 Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Wed, 20 Sep 2023 16:16:25 +0200
Subject: [PATCH 1/3] Reduce memory usage during culling for shuffling and
 merge (#8197)

---
 distributed/shuffle/_merge.py   | 14 ++++++--------
 distributed/shuffle/_shuffle.py |  7 +++++--
 2 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/distributed/shuffle/_merge.py b/distributed/shuffle/_merge.py
index 75e72e884d..7d69f40d14 100644
--- a/distributed/shuffle/_merge.py
+++ b/distributed/shuffle/_merge.py
@@ -1,7 +1,6 @@
 # mypy: ignore-errors
 from __future__ import annotations
 
-from collections import defaultdict
 from collections.abc import Iterable, Sequence
 from typing import TYPE_CHECKING, Any
 
@@ -243,15 +242,14 @@ def _cull_dependencies(
         all input partitions. This method does not require graph
         materialization.
         """
-        deps = defaultdict(set)
+        deps = {}
         parts_out = parts_out or self._keys_to_parts(keys)
+        keys = {(self.name_input_left, i) for i in range(self.npartitions)}
+        keys |= {(self.name_input_right, i) for i in range(self.npartitions)}
+        # Protect against mutations later on with frozenset
+        keys = frozenset(keys)
         for part in parts_out:
-            deps[(self.name, part)] |= {
-                (self.name_input_left, i) for i in range(self.npartitions)
-            }
-            deps[(self.name, part)] |= {
-                (self.name_input_right, i) for i in range(self.npartitions)
-            }
+            deps[(self.name, part)] = keys
         return deps
 
     def _keys_to_parts(self, keys: Iterable[str]) -> set[str]:
diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py
index 43fe87b4c5..6c80dc26d3 100644
--- a/distributed/shuffle/_shuffle.py
+++ b/distributed/shuffle/_shuffle.py
@@ -227,8 +227,11 @@ def cull(
         parameter.
         """
         parts_out = self._keys_to_parts(keys)
-        input_parts = {(self.name_input, i) for i in range(self.npartitions_input)}
-        culled_deps = {(self.name, part): input_parts.copy() for part in parts_out}
+        # Protect against mutations later on with frozenset
+        input_parts = frozenset(
+            {(self.name_input, i) for i in range(self.npartitions_input)}
+        )
+        culled_deps = {(self.name, part): input_parts for part in parts_out}
 
         if parts_out != set(self.parts_out):
             culled_layer = self._cull(parts_out)

From e2ae9e694b89950766d91158742111115393ad3f Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Fri, 22 Sep 2023 15:25:07 +0100
Subject: [PATCH 2/3] Off-by-one in the retries count in KilledWorker (#8203)

---
 distributed/scheduler.py            | 4 ++--
 distributed/tests/test_scheduler.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 98297edcba..31b939d2e8 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -8449,8 +8449,8 @@ def allowed_failures(self) -> int:
 
     def __str__(self) -> str:
         return (
-            f"Attempted to run task {self.task} on {self.allowed_failures} different "
-            "workers, but all those workers died while running it. "
+            f"Attempted to run task {self.task} on {self.allowed_failures + 1} "
+            "different workers, but all those workers died while running it. "
             f"The last worker that attempt to run the task was {self.last_worker.address}. "
             "Inspecting worker logs is often a good next step to diagnose what went wrong. "
             "For more information see https://distributed.dask.org/en/stable/killed.html."
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 19b19ce937..8d5f68f4b0 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -4244,8 +4244,8 @@ async def test_KilledWorker_informative_message(s, a, b):
     with pytest.raises(KilledWorker) as excinfo:
         raise ex
     msg = str(excinfo.value)
-    assert "Attempted to run task foo-bar" in msg
-    assert str(s.allowed_failures) in msg
+    assert "Attempted to run task foo-bar on 667 different workers" in msg
+    assert a.address in msg
     assert "worker logs" in msg
     assert "https://distributed.dask.org/en/stable/killed.html" in msg
 

From b6333df384ba4f21c090195772f20dce989435a6 Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Fri, 22 Sep 2023 15:26:26 +0100
Subject: [PATCH 3/3] Centralize and type no_default (#8171)

---
 distributed/client.py           | 14 +++++++-------
 distributed/deploy/cluster.py   |  3 ---
 distributed/scheduler.py        |  4 ++--
 distributed/tests/test_utils.py |  6 ++++++
 distributed/utils.py            |  6 +++---
 5 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/distributed/client.py b/distributed/client.py
index d02327ca9e..1422394b97 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -35,6 +35,7 @@
 from dask.core import flatten, validate_key
 from dask.highlevelgraph import HighLevelGraph
 from dask.optimization import SubgraphCallable
+from dask.typing import no_default
 from dask.utils import (
     apply,
     ensure_dict,
@@ -101,7 +102,6 @@
     import_term,
     is_python_shutting_down,
     log_errors,
-    no_default,
     sync,
     thread_state,
 )
@@ -854,7 +854,7 @@ def __init__(
         connection_limit=512,
         **kwargs,
     ):
-        if timeout == no_default:
+        if timeout is no_default:
             timeout = dask.config.get("distributed.comm.timeouts.connect")
         if timeout is not None:
             timeout = parse_timedelta(timeout, "s")
@@ -1248,7 +1248,7 @@ async def _start(self, timeout=no_default, **kwargs):
 
         await self.rpc.start()
 
-        if timeout == no_default:
+        if timeout is no_default:
             timeout = self._timeout
         if timeout is not None:
             timeout = parse_timedelta(timeout, "s")
@@ -1753,7 +1753,7 @@ def close(self, timeout=no_default):
         --------
         Client.restart
         """
-        if timeout == no_default:
+        if timeout is no_default:
             timeout = self._timeout * 2
         # XXX handling of self.status here is not thread-safe
         if self.status in ["closed", "newly-created"]:
@@ -2399,7 +2399,7 @@ async def _scatter(
         timeout=no_default,
         hash=True,
     ):
-        if timeout == no_default:
+        if timeout is no_default:
             timeout = self._timeout
         if isinstance(workers, (str, Number)):
             workers = [workers]
@@ -2588,7 +2588,7 @@ def scatter(
         --------
         Client.gather : Gather data back to local process
         """
-        if timeout == no_default:
+        if timeout is no_default:
             timeout = self._timeout
         if isinstance(data, pyQueue) or isinstance(data, Iterator):
             raise TypeError(
@@ -3577,7 +3577,7 @@ def persist(
         return result
 
     async def _restart(self, timeout=no_default, wait_for_workers=True):
-        if timeout == no_default:
+        if timeout is no_default:
             timeout = self._timeout * 4
         if timeout is not None:
             timeout = parse_timedelta(timeout, "s")
diff --git a/distributed/deploy/cluster.py b/distributed/deploy/cluster.py
index 514b97458b..6a5d02382e 100644
--- a/distributed/deploy/cluster.py
+++ b/distributed/deploy/cluster.py
@@ -33,9 +33,6 @@
 logger = logging.getLogger(__name__)
 
 
-no_default = "__no_default__"
-
-
 class Cluster(SyncMethodMixin):
     """Superclass for cluster objects
 
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 31b939d2e8..63f5092b86 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -54,6 +54,7 @@
 import dask
 import dask.utils
 from dask.core import get_deps, validate_key
+from dask.typing import no_default
 from dask.utils import (
     format_bytes,
     format_time,
@@ -119,7 +120,6 @@
     get_fileno_limit,
     key_split_group,
     log_errors,
-    no_default,
     offload,
     recursive_to_dict,
     wait_for,
@@ -7419,7 +7419,7 @@ def get_metadata(self, keys: list[str], default: Any = no_default) -> Any:
                 metadata = metadata[key]
             return metadata
         except KeyError:
-            if default != no_default:
+            if default is not no_default:
                 return default
             else:
                 raise
diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py
index 1be181422a..39378d3ba3 100644
--- a/distributed/tests/test_utils.py
+++ b/distributed/tests/test_utils.py
@@ -700,6 +700,12 @@ async def set_var(v: str) -> None:
     await asyncio.gather(set_var("foo"), set_var("bar"))
 
 
+def test_no_default_deprecated():
+    with pytest.warns(FutureWarning, match="no_default is deprecated"):
+        from distributed.utils import no_default
+    assert no_default is dask.typing.no_default
+
+
 def test_iscoroutinefunction_unhashable_input():
     # Ensure iscoroutinefunction can handle unhashable callables
     assert not iscoroutinefunction(_UnhashableCallable())
diff --git a/distributed/utils.py b/distributed/utils.py
index 45f5e48d82..093fda233a 100644
--- a/distributed/utils.py
+++ b/distributed/utils.py
@@ -95,8 +95,6 @@
 P = ParamSpec("P")
 T = TypeVar("T")
 
-no_default = "__no_default__"
-
 _forkserver_preload_set = False
 
 
@@ -1604,7 +1602,9 @@ def clean_dashboard_address(addrs: AnyType, default_listen_ip: str = "") -> list
     return addresses
 
 
-_deprecations: dict[str, str] = {}
+_deprecations = {
+    "no_default": "dask.typing.no_default",
+}
 
 
 def __getattr__(name):
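
Note on PATCH 1/3: every output partition of these layers depends on all
input partitions, so the per-output set copies built by the old code were
identical. The rewrite builds the dependency set once and lets every output
key reference the same frozenset; freezing it also protects the now-shared
object against accidental mutation. A minimal standalone sketch of the
effect (illustrative names only, not code from the patch):

    import sys

    npartitions, nparts_out = 1_000, 100

    # Old shape: one mutable set copy per output key.
    per_key = {
        ("out", p): {("in", i) for i in range(npartitions)}
        for p in range(nparts_out)
    }

    # New shape: every output key references one shared, immutable frozenset.
    shared = frozenset(("in", i) for i in range(npartitions))
    shared_deps = {("out", p): shared for p in range(nparts_out)}

    # sys.getsizeof is shallow, but the relative difference is what matters.
    one_set = sys.getsizeof(next(iter(per_key.values())))
    print(f"old: ~{one_set * nparts_out:,} bytes across {nparts_out} copies")
    print(f"new: ~{sys.getsizeof(shared):,} bytes for one shared frozenset")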
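
Note on PATCH 2/3: allowed_failures counts how many times a task may be
retried after a worker death, so by the time KilledWorker is raised the task
has run on allowed_failures + 1 workers; the old message under-reported this
by one. The updated test expects 667, consistent with that test cluster
presumably being configured with allowed_failures = 666. The corrected
arithmetic, as a sketch:

    allowed_failures = 666           # scheduler setting assumed by the test
    attempts = allowed_failures + 1  # first attempt plus all permitted retries
    assert attempts == 667           # the count KilledWorker now reports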
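
Note on PATCH 3/3: dask.typing.no_default is a sentinel object rather than
the old "__no_default__" string, so every comparison switches from equality
(==, !=) to identity (is, is not); identity checks are also safer for
arguments whose __eq__ is expensive or returns a non-boolean. The diff
registers "no_default" in _deprecations but cuts off at def __getattr__(name):,
so the shim body itself is not shown. Below is only a sketch of the usual
PEP 562 pattern such a mapping feeds, assumed rather than taken from
distributed/utils.py:

    import importlib
    import warnings

    _deprecations = {
        "no_default": "dask.typing.no_default",
    }

    def __getattr__(name):
        # Module-level __getattr__ (PEP 562): warn once per access and
        # forward deprecated names to their new home.
        if name in _deprecations:
            use_instead = _deprecations[name]
            warnings.warn(
                f"{name} is deprecated and will be removed in a future "
                f"release. Please use {use_instead} instead.",
                FutureWarning,
            )
            module, attr = use_instead.rsplit(".", 1)
            return getattr(importlib.import_module(module), attr)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")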