
Client.get_dataset to always create Futures attached to itself #3729
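As an illustrative sketch of the behaviour this PR targets (the scheduler address and dataset name below are hypothetical, not taken from the diff), a Future retrieved through get_dataset is expected to end up attached to the client that requested it, rather than to whichever client is the global default:

# Hedged sketch, not part of the PR; illustrates the intended behaviour.
from distributed import Client

c1 = Client("tcp://scheduler:8786")   # hypothetical scheduler address
c2 = Client("tcp://scheduler:8786")

fut = c1.submit(sum, [1, 2, 3])
c1.publish_dataset(fut, name="x")     # hypothetical dataset name

fut2 = c2.get_dataset("x")
assert fut2.client is c2              # attached to the requesting client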

Merged: merged 24 commits on May 7, 2020 (changes shown are from 21 commits)

Commits (24)
207c27b
Python 3.6+ syntax
crusaderky Apr 11, 2020
02b1079
Code polish
crusaderky Apr 14, 2020
026a602
Merge remote-tracking branch 'upstream/master' into get_dataset_async
crusaderky Apr 14, 2020
e07b9f8
Revert
crusaderky Apr 14, 2020
5d2566c
Polish
crusaderky Apr 14, 2020
a5dc1be
Revert "Polish"
crusaderky Apr 14, 2020
f2b6e1f
Merge remote-tracking branch 'upstream/master' into get_dataset_async
crusaderky Apr 15, 2020
f86126a
tests
crusaderky Apr 15, 2020
7ef5832
Merge branch 'master' into get_dataset_async
crusaderky Apr 16, 2020
77a0d8a
revert
crusaderky Apr 16, 2020
de6b457
Merge remote-tracking branch 'upstream/master' into get_dataset_async
crusaderky Apr 17, 2020
9d6c8c0
revert
crusaderky Apr 17, 2020
6515c1b
Merge remote-tracking branch 'upstream/master' into get_dataset_async
crusaderky Apr 18, 2020
d8fcdc9
Merge branch 'master' into get_dataset_async
crusaderky Apr 20, 2020
21522df
xfail
crusaderky Apr 20, 2020
4b55b31
Better async functions
crusaderky Apr 20, 2020
fb8f777
Use contextvars to deserialize Future
crusaderky Apr 20, 2020
804b537
Merge branch 'master' into get_dataset_async
crusaderky Apr 21, 2020
2d594f8
Merge remote-tracking branch 'upstream/master' into get_dataset_async
crusaderky Apr 27, 2020
042eddf
Redesign
crusaderky Apr 27, 2020
621573d
Tweaks
crusaderky Apr 27, 2020
77fa36b
Merge remote-tracking branch 'upstream/master' into get_dataset_async
crusaderky May 5, 2020
e70303a
docstrings
crusaderky May 5, 2020
f0f8c6e
Merge branch 'master' into get_dataset_async
crusaderky May 7, 2020
7 changes: 7 additions & 0 deletions .github/workflows/ci-windows.yaml
@@ -23,6 +23,13 @@ jobs:
activate-environment: dask-distributed
auto-activate-base: false

- name: Install contextvars
shell: bash -l {0}
run: |
if [[ "${{ matrix.python-version }}" = "3.6" ]]; then
conda install -c conda-forge contextvars
fi

- name: Install tornado
shell: bash -l {0}
run: |
4 changes: 4 additions & 0 deletions continuous_integration/travis/install.sh
@@ -67,6 +67,10 @@ conda create -n dask-distributed -c conda-forge -c defaults \

source activate dask-distributed

if [[ $PYTHON == 3.6 ]]; then
conda install -c conda-forge -c defaults contextvars
fi

# stacktrace is not currently available for Python 3.8.
# Remove the version check block below when it is available.
if [[ $PYTHON != 3.8 ]]; then
73 changes: 58 additions & 15 deletions distributed/client.py
@@ -5,6 +5,7 @@
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures._base import DoneAndNotDoneFutures
from contextlib import contextmanager
from contextvars import ContextVar
import copy
import errno
from functools import partial
@@ -89,6 +90,7 @@
_global_clients = weakref.WeakValueDictionary()
_global_client_index = [0]

_current_client = ContextVar("_current_client", default=None)

DEFAULT_EXTENSIONS = [PubSubClientExtension]
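For background, the `_current_client` variable added above relies on the standard-library `contextvars` module (which the CI changes earlier in this diff backport to Python 3.6). A minimal, self-contained sketch of the set/reset pattern used by the new code; the variable name is illustrative:

from contextvars import ContextVar

_var = ContextVar("_var", default=None)

token = _var.set("value")       # set in the current context
try:
    assert _var.get() == "value"
finally:
    _var.reset(token)           # restore the previous value

assert _var.get() is None       # back to the default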

@@ -162,7 +164,7 @@ def __init__(self, key, client=None, inform=True, state=None):
self.key = key
self._cleared = False
tkey = tokey(key)
self.client = client or _get_global_client()
self.client = client or Client.current()
self.client._inc_ref(tkey)
self._generation = self.client.generation

@@ -353,11 +355,14 @@ def release(self, _in_destructor=False):
pass # Shutting down, add_callback may be None

def __getstate__(self):
return (self.key, self.client.scheduler.address)
return self.key, self.client.scheduler.address

def __setstate__(self, state):
key, address = state
c = get_client(address)
try:
c = Client.current(allow_global=False)
except ValueError:
c = get_client(address)
Future.__init__(self, key, c)
c._send_to_scheduler(
{
@@ -727,10 +732,42 @@ def __init__(

ReplayExceptionClient(self)

@contextmanager
def as_current(self):
"""Thread-local, Task-local context manager that causes the Client.current class
method to return self. This is used when a method of Client needs to propagate a
reference to self deep into the stack through generic methods that shouldn't be
aware of this class.
"""
# In Python 3.6, contextvars are thread-local but not Task-local.
# We can still detect a race condition.
if sys.version_info < (3, 7) and _current_client.get() not in (self, None):
raise RuntimeError(
"Detected race condition where get_dataset() is invoked in "
[resolved review comment] Member: This error message still mentions get_dataset, should probably be made a bit more generic.
[resolved review comment] Collaborator (author): Fixed
"parallel by multiple asynchronous clients. "
"Please upgrade to Python 3.7+."
)

tok = _current_client.set(self)
try:
yield
finally:
_current_client.reset(tok)

@classmethod
def current(cls):
""" Return global client if one exists, otherwise raise ValueError """
return default_client()
def current(cls, allow_global=True):
"""When running within the context of `as_client`, return the context-local
current client. Otherwise, return the latest initialised Client.
If no Client instances exist, raise ValueError.
If allow_global is set to False, raise ValueError if running outside of the
`as_current` context manager.
"""
out = _current_client.get()
if out:
return out
if allow_global:
return default_client()
raise ValueError("Not running inside the `as_current` context manager")

@property
def asynchronous(self):
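A hedged usage sketch of the `as_current` / `Client.current` pair introduced above; the scheduler address is hypothetical:

from distributed import Client

client = Client("tcp://scheduler:8786")   # hypothetical address

# Outside the context manager, current() falls back to the default client.
assert Client.current() is client

with client.as_current():
    # Inside the context, the context-local client takes precedence,
    # and allow_global=False no longer raises.
    assert Client.current(allow_global=False) is client

# Outside the context, allow_global=False raises ValueError.
try:
    Client.current(allow_global=False)
except ValueError:
    pass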
@@ -2178,8 +2215,7 @@ def retry(self, futures, asynchronous=None):
"""
return self.sync(self._retry, futures, asynchronous=asynchronous)

@gen.coroutine
def _publish_dataset(self, *args, name=None, **kwargs):
async def _publish_dataset(self, *args, name=None, **kwargs):
with log_errors():
coroutines = []

@@ -2205,7 +2241,7 @@ def add_coro(name, data):
for name, data in kwargs.items():
add_coro(name, data)

yield coroutines
await asyncio.gather(*coroutines)
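The hunk above replaces a Tornado-style coroutine, where yielding a list of coroutines runs them concurrently, with a native coroutine that uses asyncio.gather. A standalone sketch of the equivalent pattern; the function names are illustrative:

import asyncio

async def fetch(i):
    await asyncio.sleep(0.01)   # stand-in for a scheduler RPC
    return i

async def main():
    coroutines = [fetch(i) for i in range(3)]
    # Tornado style inside @gen.coroutine:  results = yield coroutines
    results = await asyncio.gather(*coroutines)   # native asyncio equivalent
    return results

assert asyncio.run(main()) == [0, 1, 2]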

def publish_dataset(self, *args, **kwargs):
"""
@@ -2285,13 +2321,12 @@ def list_datasets(self, **kwargs):
return self.sync(self.scheduler.publish_list, **kwargs)

async def _get_dataset(self, name):
out = await self.scheduler.publish_get(name=name, client=self.id)
if out is None:
raise KeyError("Dataset '%s' not found" % name)
with self.as_current():
out = await self.scheduler.publish_get(name=name, client=self.id)

with temp_default_client(self):
data = out["data"]
return data
if out is None:
raise KeyError(f"Dataset '{name}' not found")
return out["data"]

def get_dataset(self, name, **kwargs):
"""
@@ -4697,6 +4732,14 @@ def __exit__(self, typ, value, traceback):
def temp_default_client(c):
""" Set the default client for the duration of the context

.. note::
This function should be used exclusively for unit testing the default client
functionality. In all other cases, please use ``Client.as_current`` instead.

.. note::
Unlike ``Client.as_current``, this context manager is neither thread-local nor
task-local.

Parameters
----------
c : Client
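Taken together, the client.py changes mean that `Future.__setstate__` prefers the context-local client and `_get_dataset` wraps the scheduler call in `as_current()`, so Futures deserialized by get_dataset attach to the client that requested them. A hedged sketch of the intended round trip; the addresses are hypothetical and plain pickle is used purely for illustration:

import pickle
from distributed import Client

a = Client("tcp://scheduler:8786")   # hypothetical address
b = Client("tcp://scheduler:8786")

fut = a.submit(sum, [1, 2, 3])
blob = pickle.dumps(fut)

# With no context-local client, __setstate__ falls back to get_client(address).
fut_default = pickle.loads(blob)

# Inside as_current(), the deserialized Future attaches to `b`,
# which is what get_dataset() now does internally.
with b.as_current():
    fut_b = pickle.loads(blob)
assert fut_b.client is b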
8 changes: 6 additions & 2 deletions distributed/lock.py
@@ -3,7 +3,7 @@
import logging
import uuid

from .client import _get_global_client
from .client import Client
from .utils import log_errors, TimeoutError
from .worker import get_worker

Expand Down Expand Up @@ -93,7 +93,11 @@ class Lock:
"""

def __init__(self, name=None, client=None):
self.client = client or _get_global_client() or get_worker().client
try:
self.client = client or Client.current()
except ValueError:
# Initialise new client
self.client = get_worker().client
self.name = name or "lock-" + uuid.uuid4().hex
self.id = uuid.uuid4().hex
self._locked = False
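This mirrors the client-resolution change in client.py: with no explicit client, Lock now goes through Client.current() and only falls back to the worker's client when that raises ValueError. A brief hedged sketch; the address and lock name are hypothetical:

from distributed import Client, Lock

client = Client("tcp://scheduler:8786")   # hypothetical address

with client.as_current():
    lock = Lock("resource-x")             # resolved via Client.current()
assert lock.client is client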
59 changes: 49 additions & 10 deletions distributed/publish.py
@@ -6,10 +6,10 @@
class PublishExtension:
""" An extension for the scheduler to manage collections

* publish-list
* publish-put
* publish-get
* publish-delete
* publish_list
* publish_put
* publish_get
* publish_delete
"""

def __init__(self, scheduler):
@@ -59,21 +59,60 @@ class Datasets(MutableMapping):

"""

__slots__ = ("_client",)

def __init__(self, client):
self.__client = client
self._client = client

def __getitem__(self, key):
return self.__client.get_dataset(key)
# When client is asynchronous, it returns a coroutine
return self._client.get_dataset(key)

def __setitem__(self, key, value):
self.__client.publish_dataset(value, name=key)
if self._client.asynchronous:
# 'await obj[key] = value' is not supported by Python as of 3.8
raise TypeError(
"Can't use 'client.datasets[name] = value' when client is "
"asynchronous; please use 'client.publish_dataset(name=value)' instead"
)
self._client.publish_dataset(value, name=key)

def __delitem__(self, key):
self.__client.unpublish_dataset(key)
if self._client.asynchronous:
# 'await del obj[key]' is not supported by Python as of 3.8
raise TypeError(
"Can't use 'del client.datasets[name]' when client is asynchronous; "
"please use 'client.unpublish_dataset(name)' instead"
)
return self._client.unpublish_dataset(key)

def __iter__(self):
for key in self.__client.list_datasets():
if self._client.asynchronous:
raise TypeError(
"Can't invoke iter() or 'for' on client.datasets when client is "
"asynchronous; use 'async for' instead"
)
for key in self._client.list_datasets():
yield key

def __aiter__(self):
if not self._client.asynchronous:
raise TypeError(
"Can't invoke 'async for' on client.datasets when client is "
"synchronous; use iter() or 'for' instead"
)

async def _():
for key in await self._client.list_datasets():
yield key

return _()

def __len__(self):
return len(self.__client.list_datasets())
if self._client.asynchronous:
# 'await len(obj)' is not supported by Python as of 3.8
raise TypeError(
"Can't use 'len(client.datasets)' when client is asynchronous; "
"please use 'len(await client.list_datasets())' instead"
)
return len(self._client.list_datasets())
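A hedged sketch of how the updated `client.datasets` mapping is expected to behave on an asynchronous client; the address is hypothetical and a dataset named 'my-dataset' is assumed to have been published already:

import asyncio
from distributed import Client

async def main():
    async with Client("tcp://scheduler:8786", asynchronous=True) as client:
        # __getitem__ returns a coroutine on an asynchronous client.
        data = await client.datasets["my-dataset"]

        # Iteration must use 'async for'; iter(), len() and item assignment
        # raise TypeError and point to the explicit coroutine APIs instead.
        async for name in client.datasets:
            print(name)

        await client.publish_dataset(data, name="my-dataset-copy")
        await client.unpublish_dataset("my-dataset-copy")

asyncio.run(main())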
6 changes: 3 additions & 3 deletions distributed/queues.py
@@ -3,7 +3,7 @@
import logging
import uuid

from .client import Future, _get_global_client, Client
from .client import Future, Client
from .utils import tokey, sync, thread_state
from .worker import get_client

@@ -148,7 +148,7 @@ class Queue:
not given, a random name will be generated.
client: Client (optional)
Client used for communication with the scheduler. Defaults to the
value of ``_get_global_client()``.
value of ``Client.current()``.
maxsize: int (optional)
Number of items allowed in the queue. If 0 (the default), the queue
size is unbounded.
@@ -167,7 +167,7 @@
"""

def __init__(self, name=None, client=None, maxsize=0):
self.client = client or _get_global_client()
self.client = client or Client.current()
self.name = name or "queue-" + uuid.uuid4().hex
self._event_started = asyncio.Event()
if self.client.asynchronous or getattr(