diff --git a/distributed/core.py b/distributed/core.py index 15205f4f72c..dcc2997d13b 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -7,6 +7,7 @@ import logging import threading import traceback +from typing import Awaitable, Callable, Dict, Iterable import uuid import weakref import warnings @@ -29,7 +30,6 @@ from . import profile from .system_monitor import SystemMonitor from .utils import ( - is_coroutine_function, get_traceback, truncate_exception, shutting_down, @@ -127,7 +127,7 @@ def __init__( self, handlers, blocked_handlers=None, - stream_handlers=None, + stream_handlers: Dict[str, Callable[..., Awaitable]] = None, connection_limit=512, deserialize=True, serializers=None, @@ -147,7 +147,7 @@ def __init__( "distributed.%s.blocked-handlers" % type(self).__name__.lower(), [] ) self.blocked_handlers = blocked_handlers - self.stream_handlers = {} + self.stream_handlers: Dict[str, Callable[..., Awaitable]] = {} self.stream_handlers.update(stream_handlers or {}) self.id = type(self).__name__ + "-" + str(uuid.uuid4()) @@ -543,11 +543,12 @@ async def handle_comm(self, comm, shutting_down=shutting_down): "Failed while closing connection to %r: %s", address, e ) - async def handle_stream(self, comm, extra=None, every_cycle=[]): + async def handle_stream( + self, comm, extra=None, every_cycle: Iterable[Callable[..., Awaitable]] = [] + ): extra = extra or {} logger.info("Starting established connection") - io_error = None closed = False try: while not closed: @@ -565,23 +566,17 @@ async def handle_stream(self, comm, extra=None, every_cycle=[]): closed = True break handler = self.stream_handlers[op] - if is_coroutine_function(handler): - self.loop.add_callback(handler, **merge(extra, msg)) - await gen.sleep(0) - else: - handler(**merge(extra, msg)) + self.loop.add_callback(handler, **merge(extra, msg)) + await gen.sleep(0) else: logger.error("odd message %s", msg) await asyncio.sleep(0) for func in every_cycle: - if is_coroutine_function(func): - self.loop.add_callback(func) - else: - func() + self.loop.add_callback(func) except (CommClosedError, EnvironmentError) as e: - io_error = e + pass except Exception as e: logger.exception(e) if LOG_PDB: diff --git a/distributed/pubsub.py b/distributed/pubsub.py index 3aeec084df1..e1483d4a4d8 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -3,6 +3,7 @@ from contextlib import suppress import logging import threading +import typing import weakref from .core import CommClosedError @@ -10,13 +11,17 @@ from .utils import sync, TimeoutError, parse_timedelta from .protocol.serialize import to_serialize +if typing.TYPE_CHECKING: + from .scheduler import Scheduler + from .worker import Worker + logger = logging.getLogger(__name__) class PubSubSchedulerExtension: """ Extend Dask's scheduler with routes to handle PubSub machinery """ - def __init__(self, scheduler): + def __init__(self, scheduler: "Scheduler"): self.scheduler = scheduler self.publishers = defaultdict(set) self.subscribers = defaultdict(set) @@ -35,7 +40,7 @@ def __init__(self, scheduler): self.scheduler.extensions["pubsub"] = self - def add_publisher(self, comm=None, name=None, worker=None): + async def add_publisher(self, comm=None, name=None, worker=None): logger.debug("Add publisher: %s %s", name, worker) self.publishers[name].add(worker) return { @@ -44,7 +49,7 @@ def add_publisher(self, comm=None, name=None, worker=None): and len(self.client_subscribers[name]) > 0, } - def add_subscriber(self, comm=None, name=None, worker=None, client=None): + async def add_subscriber(self, comm=None, name=None, worker=None, client=None): if worker: logger.debug("Add worker subscriber: %s %s", name, worker) self.subscribers[name].add(worker) @@ -62,7 +67,7 @@ def add_subscriber(self, comm=None, name=None, worker=None, client=None): ) self.client_subscribers[name].add(client) - def remove_publisher(self, comm=None, name=None, worker=None): + async def remove_publisher(self, comm=None, name=None, worker=None): if worker in self.publishers[name]: logger.debug("Remove publisher: %s %s", name, worker) self.publishers[name].remove(worker) @@ -71,7 +76,7 @@ def remove_publisher(self, comm=None, name=None, worker=None): del self.subscribers[name] del self.publishers[name] - def remove_subscriber(self, comm=None, name=None, worker=None, client=None): + async def remove_subscriber(self, comm=None, name=None, worker=None, client=None): if worker: logger.debug("Remove worker subscriber: %s %s", name, worker) self.subscribers[name].remove(worker) @@ -100,14 +105,14 @@ def remove_subscriber(self, comm=None, name=None, worker=None, client=None): del self.subscribers[name] del self.publishers[name] - def handle_message(self, name=None, msg=None, worker=None, client=None): + async def handle_message(self, name=None, msg=None, worker=None, client=None): for c in list(self.client_subscribers[name]): try: self.scheduler.client_comms[c].send( {"op": "pubsub-msg", "name": name, "msg": msg} ) except (KeyError, CommClosedError): - self.remove_subscriber(name=name, client=c) + await self.remove_subscriber(name=name, client=c) if client: for sub in self.subscribers[name]: @@ -119,7 +124,7 @@ def handle_message(self, name=None, msg=None, worker=None, client=None): class PubSubWorkerExtension: """ Extend Dask's Worker with routes to handle PubSub machinery """ - def __init__(self, worker): + def __init__(self, worker: "Worker"): self.worker = worker self.worker.stream_handlers.update( { @@ -136,15 +141,15 @@ def __init__(self, worker): self.worker.extensions["pubsub"] = self # circular reference - def add_subscriber(self, name=None, address=None, **info): + async def add_subscriber(self, name=None, address=None, **info): for pub in self.publishers[name]: pub.subscribers[address] = info - def remove_subscriber(self, name=None, address=None): + async def remove_subscriber(self, name=None, address=None): for pub in self.publishers[name]: del pub.subscribers[address] - def publish_scheduler(self, name=None, publish=None): + async def publish_scheduler(self, name=None, publish=None): self.publish_to_scheduler[name] = publish async def handle_message(self, name=None, msg=None): @@ -384,6 +389,8 @@ def __init__(self, name, worker=None, client=None): pubsub = self.worker.extensions["pubsub"] elif self.client: pubsub = self.client.extensions["pubsub"] + else: + raise ValueError("Must include a worker or client") self.loop.add_callback(pubsub.subscribers[name].add, self) msg = {"op": "pubsub-add-subscriber", "name": self.name} diff --git a/distributed/queues.py b/distributed/queues.py index 15d6c8adca5..987ffd9e344 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -1,14 +1,17 @@ import asyncio from collections import defaultdict import logging +import typing import uuid from dask.utils import stringify from .client import Future, Client -from .utils import sync, thread_state +from .utils import sync, thread_state, parse_timedelta from .worker import get_client -from .utils import parse_timedelta + +if typing.TYPE_CHECKING: + from .scheduler import Scheduler logger = logging.getLogger(__name__) @@ -25,7 +28,7 @@ class QueueExtension: * queue_size """ - def __init__(self, scheduler): + def __init__(self, scheduler: "Scheduler"): self.scheduler = scheduler self.queues = dict() self.client_refcount = dict() @@ -54,7 +57,7 @@ def create(self, comm=None, name=None, client=None, maxsize=0): else: self.client_refcount[name] += 1 - def release(self, comm=None, name=None, client=None): + async def release(self, comm=None, name=None, client=None): if name not in self.queues: return @@ -78,7 +81,7 @@ async def put( record = {"type": "msgpack", "value": data} await asyncio.wait_for(self.queues[name].put(record), timeout=timeout) - def future_release(self, name=None, key=None, client=None): + async def future_release(self, name=None, key=None, client=None): self.future_refcount[name, key] -= 1 if self.future_refcount[name, key] == 0: self.scheduler.client_releases_keys(keys=[key], client="queue-%s" % name) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index bbd8f01aeda..e481513ba05 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -5,6 +5,7 @@ from contextlib import suppress from datetime import timedelta from functools import partial +from typing import Awaitable, Callable, Dict import inspect import itertools import json @@ -3039,7 +3040,7 @@ def __init__( self.event_counts = defaultdict(int) self.worker_plugins = [] - worker_handlers = { + worker_handlers: Dict[str, Callable[..., Awaitable]] = { "task-finished": self.handle_task_finished, "task-erred": self.handle_task_erred, "release": self.handle_release_data, @@ -3052,7 +3053,7 @@ def __init__( "log-event": self.log_worker_event, } - client_handlers = { + client_handlers: Dict[str, Callable[..., Awaitable]] = { "update-graph": self.update_graph, "update-graph-hlg": self.update_graph_hlg, "client-desires-keys": self.client_desires_keys, diff --git a/distributed/stealing.py b/distributed/stealing.py index 1e8428854ff..2e49115200c 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -2,6 +2,7 @@ import logging from math import log2 from time import time +import typing from tornado.ioloop import PeriodicCallback @@ -11,6 +12,9 @@ from .diagnostics.plugin import SchedulerPlugin from .utils import log_errors, parse_timedelta +if typing.TYPE_CHECKING: + from .scheduler import Scheduler + from tlz import topk LATENCY = 10e-3 @@ -22,7 +26,7 @@ class WorkStealing(SchedulerPlugin): - def __init__(self, scheduler): + def __init__(self, scheduler: "Scheduler"): self.scheduler = scheduler # { level: { task states } } self.stealable_all = [set() for i in range(15)] diff --git a/distributed/utils.py b/distributed/utils.py index 30044740f64..a0c6460b235 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -27,6 +27,7 @@ import pkgutil import base64 import tblib.pickling_support +from typing import Awaitable, Callable, TypeVar import xml.etree.ElementTree try: @@ -289,6 +290,21 @@ def quiet(): return results +T = TypeVar("T") + + +def asyncify(func: Callable[..., T]) -> Callable[..., Awaitable[T]]: + """ + Wrap a synchronous function in an async one + """ + + @functools.wraps(func) + async def wrapped(*args, **kwargs): + return func(*args, **kwargs) + + return wrapped + + def sync(loop, func, *args, callback_timeout=None, **kwargs): """ Run coroutine in loop running in separate thread. diff --git a/distributed/variable.py b/distributed/variable.py index db8da76e44c..4d9f675ff98 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -2,6 +2,7 @@ from collections import defaultdict from contextlib import suppress import logging +import typing import uuid from tlz import merge @@ -11,6 +12,9 @@ from .utils import log_errors, TimeoutError, parse_timedelta from .worker import get_client +if typing.TYPE_CHECKING: + from .scheduler import Scheduler + logger = logging.getLogger(__name__) @@ -24,7 +28,7 @@ class VariableExtension: * variable-delete """ - def __init__(self, scheduler): + def __init__(self, scheduler: "Scheduler"): self.scheduler = scheduler self.variables = dict() self.waiting = defaultdict(set) diff --git a/distributed/worker.py b/distributed/worker.py index 4cf60b597e9..38fc5ba9509 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -15,6 +15,7 @@ import random import threading import sys +from typing import Awaitable, Callable, Dict import uuid import warnings import weakref @@ -47,6 +48,7 @@ from .sizeof import safe_sizeof as sizeof from .threadpoolexecutor import ThreadPoolExecutor, secede as tpe_secede from .utils import ( + asyncify, get_ip, typename, has_arg, @@ -676,12 +678,12 @@ def __init__( "plugin-add": self.plugin_add, } - stream_handlers = { + stream_handlers: Dict[str, Callable[..., Awaitable]] = { "close": self.close, - "compute-task": self.add_task, - "release-task": partial(self.release_key, report=False), - "delete-data": self.delete_data, - "steal-request": self.steal_request, + "compute-task": asyncify(self.add_task), + "release-task": asyncify(partial(self.release_key, report=False)), + "delete-data": asyncify(self.delete_data), + "steal-request": asyncify(self.steal_request), } super().__init__( @@ -985,7 +987,11 @@ async def heartbeat(self): async def handle_scheduler(self, comm): try: await self.handle_stream( - comm, every_cycle=[self.ensure_communicating, self.ensure_computing] + comm, + every_cycle=[ + asyncify(self.ensure_communicating), + self.ensure_computing, + ], ) except Exception as e: logger.exception(e)