Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

avoid calls to make_current() and make_clear() by using asyncio.run in LoopRunner #7467

Merged
merged 12 commits into from
Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 2 additions & 32 deletions distributed/actor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from __future__ import annotations

import abc
import asyncio
import functools
import sys
import threading
from collections.abc import Awaitable, Generator
from dataclasses import dataclass
Expand All @@ -14,41 +12,13 @@

from distributed.client import Future
from distributed.protocol import to_serialize
from distributed.utils import iscoroutinefunction, sync, thread_state
from distributed.utils import LateLoopEvent, iscoroutinefunction, sync, thread_state
from distributed.utils_comm import WrappedKey
from distributed.worker import get_client, get_worker

_T = TypeVar("_T")


if sys.version_info >= (3, 10):
from asyncio import Event as _LateLoopEvent
else:
# In python 3.10 asyncio.Lock and other primitives no longer support
# passing a loop kwarg to bind to a loop running in another thread
# e.g. calling from Client(asynchronous=False). Instead the loop is bound
# as late as possible: when calling any methods that wait on or wake
# Future instances. See: https://bugs.python.org/issue42392
class _LateLoopEvent:
def __init__(self) -> None:
self._event: asyncio.Event | None = None

def set(self) -> None:
if self._event is None:
self._event = asyncio.Event()

self._event.set()

def is_set(self) -> bool:
return self._event is not None and self._event.is_set()

async def wait(self) -> bool:
if self._event is None:
self._event = asyncio.Event()

return await self._event.wait()


class Actor(WrappedKey):
"""Controls an object on a remote worker

Expand Down Expand Up @@ -322,7 +292,7 @@ def unwrap(self) -> NoReturn:
class ActorFuture(BaseActorFuture[_T]):
def __init__(self, io_loop: IOLoop):
self._io_loop = io_loop
self._event = _LateLoopEvent()
self._event = LateLoopEvent()
self._out: _Error | _OK[_T] | None = None

def __await__(self) -> Generator[object, None, _T]:
Expand Down
17 changes: 12 additions & 5 deletions distributed/deploy/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,16 +206,17 @@ async def _close(self):

self.status = Status.closed

def close(self, timeout=None):
def close(self, timeout: float | None = None) -> Any:
# If the cluster is already closed, we're already done
if self.status == Status.closed:
if self.asynchronous:
return NoOpAwaitable()
else:
return
return None

with suppress(RuntimeError): # loop closed during process shutdown
try:
return self.sync(self._close, callback_timeout=timeout)
except RuntimeError: # loop closed during process shutdown
return None

def __del__(self, _warn=warnings.warn):
if getattr(self, "status", Status.closed) != Status.closed:
Expand Down Expand Up @@ -519,10 +520,16 @@ def _ipython_display_(self, **kwargs):
display(mimebundle, raw=True)

def __enter__(self):
if self.asynchronous:
raise TypeError(
"Used 'with' with asynchronous class; please use 'async with'"
)

return self.sync(self.__aenter__)

def __exit__(self, exc_type, exc_value, traceback):
return self.sync(self.__aexit__, exc_type, exc_value, traceback)
aw = self.close()
assert aw is None, aw
crusaderky marked this conversation as resolved.
Show resolved Hide resolved

def __await__(self):
return self
Expand Down
11 changes: 7 additions & 4 deletions distributed/deploy/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,15 @@ def __init__(
self.sync(self._correct_state)
except Exception:
self.sync(self.close)
self._loop_runner.stop()
raise

def close(self, timeout: float | None = None) -> Awaitable[None] | None:
aw = super().close(timeout)
if not self.asynchronous:
self._loop_runner.stop()
return aw

async def _start(self):
while self.status == Status.starting:
await asyncio.sleep(0.01)
Expand Down Expand Up @@ -472,10 +479,6 @@ async def __aenter__(self):
assert self.status == Status.running
return self

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._loop_runner.stop()

def _threads_per_worker(self) -> int:
"""Return the number of threads per worker for new workers"""
if not self.new_spec: # pragma: no cover
Expand Down
5 changes: 1 addition & 4 deletions distributed/deploy/tests/test_adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,6 @@ async def test_no_more_workers_than_tasks():
assert len(cluster.scheduler.workers) <= 1


@pytest.mark.filterwarnings("ignore:There is no current event loop:DeprecationWarning")
@pytest.mark.filterwarnings("ignore:make_current is deprecated:DeprecationWarning")
def test_basic_no_loop(cleanup):
loop = None
try:
Expand All @@ -293,8 +291,7 @@ def test_basic_no_loop(cleanup):
assert future.result() == 2
loop = cluster.loop
finally:
if loop is not None:
loop.add_callback(loop.stop)
assert loop is None or not loop.asyncio_loop.is_running()


@pytest.mark.flaky(condition=LINUX, reruns=10, reruns_delay=5)
Expand Down
10 changes: 10 additions & 0 deletions distributed/deploy/tests/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,13 @@ def test_exponential_backoff():
assert _exponential_backoff(5, 1.5, 3, 20) == 20
# avoid overflow
assert _exponential_backoff(1000, 1.5, 3, 20) == 20


@gen_test()
async def test_sync_context_manager_used_with_async_cluster():
async with Cluster(asynchronous=True, name="A") as cluster:
with pytest.raises(
TypeError,
match=r"Used 'with' with asynchronous class; please use 'async with'",
), cluster:
pass
2 changes: 0 additions & 2 deletions distributed/deploy/tests/test_spec_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ def test_spec_sync(loop):
assert result == 11


@pytest.mark.filterwarnings("ignore:There is no current event loop:DeprecationWarning")
@pytest.mark.filterwarnings("ignore:make_current is deprecated:DeprecationWarning")
def test_loop_started_in_constructor(cleanup):
# test that SpecCluster.__init__ starts a loop in another thread
cluster = SpecCluster(worker_spec, scheduler=scheduler, loop=None)
Expand Down
6 changes: 3 additions & 3 deletions distributed/tests/test_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
wait,
worker_client,
)
from distributed.actor import _LateLoopEvent
from distributed.metrics import time
from distributed.utils import LateLoopEvent
from distributed.utils_test import cluster, double, gen_cluster, inc
from distributed.worker import get_worker

Expand Down Expand Up @@ -263,7 +263,7 @@ def test_sync(client):
def test_timeout(client):
class Waiter:
def __init__(self):
self.event = _LateLoopEvent()
self.event = LateLoopEvent()

async def set(self):
self.event.set()
Expand Down Expand Up @@ -555,7 +555,7 @@ def sleep(self, time):
async def test_waiter(c, s, a, b):
class Waiter:
def __init__(self):
self.event = _LateLoopEvent()
self.event = LateLoopEvent()

async def set(self):
self.event.set()
Expand Down
55 changes: 7 additions & 48 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,7 @@
from distributed.scheduler import CollectTaskMetaDataPlugin, KilledWorker, Scheduler
from distributed.shuffle import check_minimal_arrow_version
from distributed.sizeof import sizeof
from distributed.utils import (
NoOpAwaitable,
get_mp_context,
is_valid_xml,
open_port,
sync,
tmp_text,
)
from distributed.utils import get_mp_context, is_valid_xml, open_port, sync, tmp_text
from distributed.utils_test import (
NO_AMM,
BlockedGatherDep,
Expand Down Expand Up @@ -2106,27 +2099,8 @@ async def test_multi_client(s, a, b):
await asyncio.sleep(0.01)


@contextmanager
def _pristine_loop():
IOLoop.clear_instance()
IOLoop.clear_current()
loop = IOLoop()
loop.make_current()
assert IOLoop.current() is loop
try:
yield loop
finally:
try:
loop.close(all_fds=True)
except (KeyError, ValueError):
pass
IOLoop.clear_instance()
IOLoop.clear_current()


def long_running_client_connection(address):
with _pristine_loop():
c = Client(address)
with Client(address, loop=None) as c:
x = c.submit(lambda x: x + 1, 10)
x.result()
sleep(100)
Expand Down Expand Up @@ -2779,8 +2753,6 @@ async def test_startup_close_startup(s, a, b):
pass


@pytest.mark.filterwarnings("ignore:There is no current event loop:DeprecationWarning")
@pytest.mark.filterwarnings("ignore:make_current is deprecated:DeprecationWarning")
def test_startup_close_startup_sync(loop):
with cluster() as (s, [a, b]):
with Client(s["address"], loop=loop) as c:
Expand Down Expand Up @@ -5652,23 +5624,12 @@ async def test_future_auto_inform(c, s, a, b):
await asyncio.sleep(0.01)


@pytest.mark.filterwarnings("ignore:There is no current event loop:DeprecationWarning")
@pytest.mark.filterwarnings("ignore:make_current is deprecated:DeprecationWarning")
@pytest.mark.filterwarnings("ignore:clear_current is deprecated:DeprecationWarning")
def test_client_async_before_loop_starts(cleanup):
async def close():
async with client:
pass

with _pristine_loop() as loop:
with pytest.warns(
DeprecationWarning,
match=r"Constructing LoopRunner\(loop=loop\) without a running loop is deprecated",
):
client = Client(asynchronous=True, loop=loop)
assert client.asynchronous
assert isinstance(client.close(), NoOpAwaitable)
loop.run_sync(close) # TODO: client.close() does not unset global client
with pytest.raises(
RuntimeError,
match=r"Constructing LoopRunner\(asynchronous=True\) without a running loop is not supported",
):
client = Client(asynchronous=True, loop=None)


@pytest.mark.slow
Expand Down Expand Up @@ -7110,8 +7071,6 @@ async def test_workers_collection_restriction(c, s, a, b):
assert a.data and not b.data


@pytest.mark.filterwarnings("ignore:There is no current event loop:DeprecationWarning")
@pytest.mark.filterwarnings("ignore:make_current is deprecated:DeprecationWarning")
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
async def test_get_client_functions_spawn_clusters(c, s, a):
# see gh4565
Expand Down
6 changes: 0 additions & 6 deletions distributed/tests/test_client_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import contextlib

import pytest

from distributed import Client, LocalCluster
from distributed.utils import LoopRunner

Expand All @@ -29,16 +27,12 @@ def _check_cluster_and_client_loop(loop):


# Test if Client stops LoopRunner on close.
@pytest.mark.filterwarnings("ignore:There is no current event loop:DeprecationWarning")
@pytest.mark.filterwarnings("ignore:make_current is deprecated:DeprecationWarning")
def test_close_loop_sync_start_new_loop(cleanup):
with _check_loop_runner():
_check_cluster_and_client_loop(loop=None)


# Test if Client stops LoopRunner on close.
@pytest.mark.filterwarnings("ignore:There is no current event loop:DeprecationWarning")
@pytest.mark.filterwarnings("ignore:make_current is deprecated:DeprecationWarning")
def test_close_loop_sync_use_running_loop(cleanup):
with _check_loop_runner():
# Start own loop or use current thread's one.
Expand Down
Loading