Skip to content

Commit

Permalink
Centralize and type no_default (#8171)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored Sep 22, 2023
1 parent e2ae9e6 commit b6333df
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 15 deletions.
14 changes: 7 additions & 7 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from dask.core import flatten, validate_key
from dask.highlevelgraph import HighLevelGraph
from dask.optimization import SubgraphCallable
from dask.typing import no_default
from dask.utils import (
apply,
ensure_dict,
Expand Down Expand Up @@ -101,7 +102,6 @@
import_term,
is_python_shutting_down,
log_errors,
no_default,
sync,
thread_state,
)
Expand Down Expand Up @@ -854,7 +854,7 @@ def __init__(
connection_limit=512,
**kwargs,
):
if timeout == no_default:
if timeout is no_default:
timeout = dask.config.get("distributed.comm.timeouts.connect")
if timeout is not None:
timeout = parse_timedelta(timeout, "s")
Expand Down Expand Up @@ -1248,7 +1248,7 @@ async def _start(self, timeout=no_default, **kwargs):

await self.rpc.start()

if timeout == no_default:
if timeout is no_default:
timeout = self._timeout
if timeout is not None:
timeout = parse_timedelta(timeout, "s")
Expand Down Expand Up @@ -1753,7 +1753,7 @@ def close(self, timeout=no_default):
--------
Client.restart
"""
if timeout == no_default:
if timeout is no_default:
timeout = self._timeout * 2
# XXX handling of self.status here is not thread-safe
if self.status in ["closed", "newly-created"]:
Expand Down Expand Up @@ -2399,7 +2399,7 @@ async def _scatter(
timeout=no_default,
hash=True,
):
if timeout == no_default:
if timeout is no_default:
timeout = self._timeout
if isinstance(workers, (str, Number)):
workers = [workers]
Expand Down Expand Up @@ -2588,7 +2588,7 @@ def scatter(
--------
Client.gather : Gather data back to local process
"""
if timeout == no_default:
if timeout is no_default:
timeout = self._timeout
if isinstance(data, pyQueue) or isinstance(data, Iterator):
raise TypeError(
Expand Down Expand Up @@ -3577,7 +3577,7 @@ def persist(
return result

async def _restart(self, timeout=no_default, wait_for_workers=True):
if timeout == no_default:
if timeout is no_default:
timeout = self._timeout * 4
if timeout is not None:
timeout = parse_timedelta(timeout, "s")
Expand Down
3 changes: 0 additions & 3 deletions distributed/deploy/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@
logger = logging.getLogger(__name__)


no_default = "__no_default__"


class Cluster(SyncMethodMixin):
"""Superclass for cluster objects
Expand Down
4 changes: 2 additions & 2 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import dask
import dask.utils
from dask.core import get_deps, validate_key
from dask.typing import no_default
from dask.utils import (
format_bytes,
format_time,
Expand Down Expand Up @@ -119,7 +120,6 @@
get_fileno_limit,
key_split_group,
log_errors,
no_default,
offload,
recursive_to_dict,
wait_for,
Expand Down Expand Up @@ -7419,7 +7419,7 @@ def get_metadata(self, keys: list[str], default: Any = no_default) -> Any:
metadata = metadata[key]
return metadata
except KeyError:
if default != no_default:
if default is not no_default:
return default
else:
raise
Expand Down
6 changes: 6 additions & 0 deletions distributed/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,12 @@ async def set_var(v: str) -> None:
await asyncio.gather(set_var("foo"), set_var("bar"))


def test_no_default_deprecated():
with pytest.warns(FutureWarning, match="no_default is deprecated"):
from distributed.utils import no_default
assert no_default is dask.typing.no_default


def test_iscoroutinefunction_unhashable_input():
# Ensure iscoroutinefunction can handle unhashable callables
assert not iscoroutinefunction(_UnhashableCallable())
Expand Down
6 changes: 3 additions & 3 deletions distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,6 @@
P = ParamSpec("P")
T = TypeVar("T")

no_default = "__no_default__"

_forkserver_preload_set = False


Expand Down Expand Up @@ -1604,7 +1602,9 @@ def clean_dashboard_address(addrs: AnyType, default_listen_ip: str = "") -> list
return addresses


_deprecations: dict[str, str] = {}
_deprecations = {
"no_default": "dask.typing.no_default",
}


def __getattr__(name):
Expand Down

0 comments on commit b6333df

Please sign in to comment.