Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Async stream handlers #4474

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 10 additions & 15 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import threading
import traceback
from typing import Awaitable, Callable, Dict, Iterable
import uuid
import weakref
import warnings
Expand All @@ -29,7 +30,6 @@
from . import profile
from .system_monitor import SystemMonitor
from .utils import (
is_coroutine_function,
get_traceback,
truncate_exception,
shutting_down,
Expand Down Expand Up @@ -127,7 +127,7 @@ def __init__(
self,
handlers,
blocked_handlers=None,
stream_handlers=None,
stream_handlers: Dict[str, Callable[..., Awaitable]] = None,
connection_limit=512,
deserialize=True,
serializers=None,
Expand All @@ -147,7 +147,7 @@ def __init__(
"distributed.%s.blocked-handlers" % type(self).__name__.lower(), []
)
self.blocked_handlers = blocked_handlers
self.stream_handlers = {}
self.stream_handlers: Dict[str, Callable[..., Awaitable]] = {}
self.stream_handlers.update(stream_handlers or {})

self.id = type(self).__name__ + "-" + str(uuid.uuid4())
Expand Down Expand Up @@ -543,11 +543,12 @@ async def handle_comm(self, comm, shutting_down=shutting_down):
"Failed while closing connection to %r: %s", address, e
)

async def handle_stream(self, comm, extra=None, every_cycle=[]):
async def handle_stream(
self, comm, extra=None, every_cycle: Iterable[Callable[..., Awaitable]] = []
):
extra = extra or {}
logger.info("Starting established connection")

io_error = None
closed = False
try:
while not closed:
Expand All @@ -565,23 +566,17 @@ async def handle_stream(self, comm, extra=None, every_cycle=[]):
closed = True
break
handler = self.stream_handlers[op]
if is_coroutine_function(handler):
self.loop.add_callback(handler, **merge(extra, msg))
await gen.sleep(0)
else:
handler(**merge(extra, msg))
self.loop.add_callback(handler, **merge(extra, msg))
await gen.sleep(0)
else:
logger.error("odd message %s", msg)
await asyncio.sleep(0)

for func in every_cycle:
if is_coroutine_function(func):
self.loop.add_callback(func)
else:
func()
self.loop.add_callback(func)

except (CommClosedError, EnvironmentError) as e:
io_error = e
pass
except Exception as e:
logger.exception(e)
if LOG_PDB:
Expand Down
29 changes: 18 additions & 11 deletions distributed/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,25 @@
from contextlib import suppress
import logging
import threading
import typing
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as a personal esthetic preference, I find that using from typing import ... is always a nicer option. It does lead to collisions with from collections.abc import ..., which can however be worked around after you drop Python 3.6 support, add from __future__ import annotations everywhere, and you run a Python 3.9-compatible version of mypy (on python 3.7+).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I generally agree with you here @crusaderky. In this case I kept it this way because the thing I'm using is typing.TYPE_CHECKING, and I didn't really want that floating around the module namespace. No strong preference either way, though.

import weakref

from .core import CommClosedError
from .metrics import time
from .utils import sync, TimeoutError, parse_timedelta
from .protocol.serialize import to_serialize

if typing.TYPE_CHECKING:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found adding type annotations to the stream_handlers object was useful in implementing this (cf #2803), allowing type checkers to catch which handlers were async or sync. However, I needed this pretty unsightly hack to avoid a circular import, just to get the type name in the module scope. I'm not happy about it, and it could be removed. On the other hand, it's kind of nice for refactoring.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does from __future__ import annotations help at all here? Admittedly that is Python 3.7+ only. Though we are planning to drop Python 3.6 soon ( #4390 )

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I had to remove annotations in 4f1944b for 3.6 support. Though it didn't help that much, it mostly meant I didn't have to quote the type names below. The circular import problem remained.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jakirkham it doesn't help - you still need those imports. from __future__ import annotations helps when you have a class accepting itself or returning itself in its methods, or when class A in a module accepts or returns class B which is declared afterwards in the same module.

from .scheduler import Scheduler
from .worker import Worker

logger = logging.getLogger(__name__)


class PubSubSchedulerExtension:
""" Extend Dask's scheduler with routes to handle PubSub machinery """

def __init__(self, scheduler):
def __init__(self, scheduler: "Scheduler"):
self.scheduler = scheduler
self.publishers = defaultdict(set)
self.subscribers = defaultdict(set)
Expand All @@ -35,7 +40,7 @@ def __init__(self, scheduler):

self.scheduler.extensions["pubsub"] = self

def add_publisher(self, comm=None, name=None, worker=None):
async def add_publisher(self, comm=None, name=None, worker=None):
logger.debug("Add publisher: %s %s", name, worker)
self.publishers[name].add(worker)
return {
Expand All @@ -44,7 +49,7 @@ def add_publisher(self, comm=None, name=None, worker=None):
and len(self.client_subscribers[name]) > 0,
}

def add_subscriber(self, comm=None, name=None, worker=None, client=None):
async def add_subscriber(self, comm=None, name=None, worker=None, client=None):
if worker:
logger.debug("Add worker subscriber: %s %s", name, worker)
self.subscribers[name].add(worker)
Expand All @@ -62,7 +67,7 @@ def add_subscriber(self, comm=None, name=None, worker=None, client=None):
)
self.client_subscribers[name].add(client)

def remove_publisher(self, comm=None, name=None, worker=None):
async def remove_publisher(self, comm=None, name=None, worker=None):
if worker in self.publishers[name]:
logger.debug("Remove publisher: %s %s", name, worker)
self.publishers[name].remove(worker)
Expand All @@ -71,7 +76,7 @@ def remove_publisher(self, comm=None, name=None, worker=None):
del self.subscribers[name]
del self.publishers[name]

def remove_subscriber(self, comm=None, name=None, worker=None, client=None):
async def remove_subscriber(self, comm=None, name=None, worker=None, client=None):
if worker:
logger.debug("Remove worker subscriber: %s %s", name, worker)
self.subscribers[name].remove(worker)
Expand Down Expand Up @@ -100,14 +105,14 @@ def remove_subscriber(self, comm=None, name=None, worker=None, client=None):
del self.subscribers[name]
del self.publishers[name]

def handle_message(self, name=None, msg=None, worker=None, client=None):
async def handle_message(self, name=None, msg=None, worker=None, client=None):
for c in list(self.client_subscribers[name]):
try:
self.scheduler.client_comms[c].send(
{"op": "pubsub-msg", "name": name, "msg": msg}
)
except (KeyError, CommClosedError):
self.remove_subscriber(name=name, client=c)
await self.remove_subscriber(name=name, client=c)

if client:
for sub in self.subscribers[name]:
Expand All @@ -119,7 +124,7 @@ def handle_message(self, name=None, msg=None, worker=None, client=None):
class PubSubWorkerExtension:
""" Extend Dask's Worker with routes to handle PubSub machinery """

def __init__(self, worker):
def __init__(self, worker: "Worker"):
self.worker = worker
self.worker.stream_handlers.update(
{
Expand All @@ -136,15 +141,15 @@ def __init__(self, worker):

self.worker.extensions["pubsub"] = self # circular reference

def add_subscriber(self, name=None, address=None, **info):
async def add_subscriber(self, name=None, address=None, **info):
for pub in self.publishers[name]:
pub.subscribers[address] = info

def remove_subscriber(self, name=None, address=None):
async def remove_subscriber(self, name=None, address=None):
for pub in self.publishers[name]:
del pub.subscribers[address]

def publish_scheduler(self, name=None, publish=None):
async def publish_scheduler(self, name=None, publish=None):
self.publish_to_scheduler[name] = publish

async def handle_message(self, name=None, msg=None):
Expand Down Expand Up @@ -384,6 +389,8 @@ def __init__(self, name, worker=None, client=None):
pubsub = self.worker.extensions["pubsub"]
elif self.client:
pubsub = self.client.extensions["pubsub"]
else:
raise ValueError("Must include a worker or client")
self.loop.add_callback(pubsub.subscribers[name].add, self)

msg = {"op": "pubsub-add-subscriber", "name": self.name}
Expand Down
13 changes: 8 additions & 5 deletions distributed/queues.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import asyncio
from collections import defaultdict
import logging
import typing
import uuid

from dask.utils import stringify

from .client import Future, Client
from .utils import sync, thread_state
from .utils import sync, thread_state, parse_timedelta
from .worker import get_client
from .utils import parse_timedelta

if typing.TYPE_CHECKING:
from .scheduler import Scheduler

logger = logging.getLogger(__name__)

Expand All @@ -25,7 +28,7 @@ class QueueExtension:
* queue_size
"""

def __init__(self, scheduler):
def __init__(self, scheduler: "Scheduler"):
self.scheduler = scheduler
self.queues = dict()
self.client_refcount = dict()
Expand Down Expand Up @@ -54,7 +57,7 @@ def create(self, comm=None, name=None, client=None, maxsize=0):
else:
self.client_refcount[name] += 1

def release(self, comm=None, name=None, client=None):
async def release(self, comm=None, name=None, client=None):
if name not in self.queues:
return

Expand All @@ -78,7 +81,7 @@ async def put(
record = {"type": "msgpack", "value": data}
await asyncio.wait_for(self.queues[name].put(record), timeout=timeout)

def future_release(self, name=None, key=None, client=None):
async def future_release(self, name=None, key=None, client=None):
self.future_refcount[name, key] -= 1
if self.future_refcount[name, key] == 0:
self.scheduler.client_releases_keys(keys=[key], client="queue-%s" % name)
Expand Down
5 changes: 3 additions & 2 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from contextlib import suppress
from datetime import timedelta
from functools import partial
from typing import Awaitable, Callable, Dict
import inspect
import itertools
import json
Expand Down Expand Up @@ -3039,7 +3040,7 @@ def __init__(
self.event_counts = defaultdict(int)
self.worker_plugins = []

worker_handlers = {
worker_handlers: Dict[str, Callable[..., Awaitable]] = {
"task-finished": self.handle_task_finished,
"task-erred": self.handle_task_erred,
"release": self.handle_release_data,
Expand All @@ -3052,7 +3053,7 @@ def __init__(
"log-event": self.log_worker_event,
}

client_handlers = {
client_handlers: Dict[str, Callable[..., Awaitable]] = {
"update-graph": self.update_graph,
"update-graph-hlg": self.update_graph_hlg,
"client-desires-keys": self.client_desires_keys,
Expand Down
6 changes: 5 additions & 1 deletion distributed/stealing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from math import log2
from time import time
import typing

from tornado.ioloop import PeriodicCallback

Expand All @@ -11,6 +12,9 @@
from .diagnostics.plugin import SchedulerPlugin
from .utils import log_errors, parse_timedelta

if typing.TYPE_CHECKING:
from .scheduler import Scheduler

from tlz import topk

LATENCY = 10e-3
Expand All @@ -22,7 +26,7 @@


class WorkStealing(SchedulerPlugin):
def __init__(self, scheduler):
def __init__(self, scheduler: "Scheduler"):
self.scheduler = scheduler
# { level: { task states } }
self.stealable_all = [set() for i in range(15)]
Expand Down
16 changes: 16 additions & 0 deletions distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import pkgutil
import base64
import tblib.pickling_support
from typing import Awaitable, Callable, TypeVar
import xml.etree.ElementTree

try:
Expand Down Expand Up @@ -289,6 +290,21 @@ def quiet():
return results


T = TypeVar("T")


def asyncify(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In some places I've used an async wrapper to have a lighter footprint on the API, but more functions could just be made async in the first place.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I almost raised a question about this until I looked at how these methods were used. I agree that making them sync by default is probably the right choice. This seemed like a good solution to me.

"""
Wrap a synchronous function in an async one
"""

@functools.wraps(func)
async def wrapped(*args, **kwargs):
return func(*args, **kwargs)

return wrapped


def sync(loop, func, *args, callback_timeout=None, **kwargs):
"""
Run coroutine in loop running in separate thread.
Expand Down
6 changes: 5 additions & 1 deletion distributed/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import defaultdict
from contextlib import suppress
import logging
import typing
import uuid

from tlz import merge
Expand All @@ -11,6 +12,9 @@
from .utils import log_errors, TimeoutError, parse_timedelta
from .worker import get_client

if typing.TYPE_CHECKING:
from .scheduler import Scheduler

logger = logging.getLogger(__name__)


Expand All @@ -24,7 +28,7 @@ class VariableExtension:
* variable-delete
"""

def __init__(self, scheduler):
def __init__(self, scheduler: "Scheduler"):
self.scheduler = scheduler
self.variables = dict()
self.waiting = defaultdict(set)
Expand Down
18 changes: 12 additions & 6 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import random
import threading
import sys
from typing import Awaitable, Callable, Dict
import uuid
import warnings
import weakref
Expand Down Expand Up @@ -47,6 +48,7 @@
from .sizeof import safe_sizeof as sizeof
from .threadpoolexecutor import ThreadPoolExecutor, secede as tpe_secede
from .utils import (
asyncify,
get_ip,
typename,
has_arg,
Expand Down Expand Up @@ -676,12 +678,12 @@ def __init__(
"plugin-add": self.plugin_add,
}

stream_handlers = {
stream_handlers: Dict[str, Callable[..., Awaitable]] = {
"close": self.close,
"compute-task": self.add_task,
"release-task": partial(self.release_key, report=False),
"delete-data": self.delete_data,
"steal-request": self.steal_request,
"compute-task": asyncify(self.add_task),
"release-task": asyncify(partial(self.release_key, report=False)),
"delete-data": asyncify(self.delete_data),
"steal-request": asyncify(self.steal_request),
}

super().__init__(
Expand Down Expand Up @@ -985,7 +987,11 @@ async def heartbeat(self):
async def handle_scheduler(self, comm):
try:
await self.handle_stream(
comm, every_cycle=[self.ensure_communicating, self.ensure_computing]
comm,
every_cycle=[
asyncify(self.ensure_communicating),
self.ensure_computing,
],
)
except Exception as e:
logger.exception(e)
Expand Down