From 36e026ee39f2d553d800f52b23677366ca1fc37f Mon Sep 17 00:00:00 2001 From: Ian Rose Date: Fri, 29 Jan 2021 12:47:06 -0800 Subject: [PATCH 1/6] Type stream handlers as async-only --- distributed/core.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 15205f4f72c..ff339fc872a 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -7,6 +7,7 @@ import logging import threading import traceback +from typing import Callable, Dict, Awaitable import uuid import weakref import warnings @@ -127,7 +128,7 @@ def __init__( self, handlers, blocked_handlers=None, - stream_handlers=None, + stream_handlers: Dict[str, Callable[..., Awaitable]] = None, connection_limit=512, deserialize=True, serializers=None, @@ -147,7 +148,7 @@ def __init__( "distributed.%s.blocked-handlers" % type(self).__name__.lower(), [] ) self.blocked_handlers = blocked_handlers - self.stream_handlers = {} + self.stream_handlers: Dict[str, Callable[..., Awaitable]] = {} self.stream_handlers.update(stream_handlers or {}) self.id = type(self).__name__ + "-" + str(uuid.uuid4()) @@ -547,7 +548,6 @@ async def handle_stream(self, comm, extra=None, every_cycle=[]): extra = extra or {} logger.info("Starting established connection") - io_error = None closed = False try: while not closed: @@ -565,11 +565,8 @@ async def handle_stream(self, comm, extra=None, every_cycle=[]): closed = True break handler = self.stream_handlers[op] - if is_coroutine_function(handler): - self.loop.add_callback(handler, **merge(extra, msg)) - await gen.sleep(0) - else: - handler(**merge(extra, msg)) + self.loop.add_callback(handler, **merge(extra, msg)) + await gen.sleep(0) else: logger.error("odd message %s", msg) await asyncio.sleep(0) @@ -581,7 +578,7 @@ async def handle_stream(self, comm, extra=None, every_cycle=[]): func() except (CommClosedError, EnvironmentError) as e: - io_error = e + pass except Exception as e: logger.exception(e) if LOG_PDB: From d682c136b1854d401d052ecce75d2a80ce6571c1 Mon Sep 17 00:00:00 2001 From: Ian Rose Date: Fri, 29 Jan 2021 13:48:56 -0800 Subject: [PATCH 2/6] Asyncify some stream handlers --- distributed/pubsub.py | 26 +++++++++++++++----------- distributed/queues.py | 10 +++++----- distributed/scheduler.py | 5 +++-- distributed/stealing.py | 3 ++- distributed/variable.py | 3 ++- distributed/worker.py | 11 ++++++----- 6 files changed, 33 insertions(+), 25 deletions(-) diff --git a/distributed/pubsub.py b/distributed/pubsub.py index 3aeec084df1..e357d0173fd 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -9,6 +9,8 @@ from .metrics import time from .utils import sync, TimeoutError, parse_timedelta from .protocol.serialize import to_serialize +from .scheduler import Scheduler +from .worker import Worker logger = logging.getLogger(__name__) @@ -16,7 +18,7 @@ class PubSubSchedulerExtension: """ Extend Dask's scheduler with routes to handle PubSub machinery """ - def __init__(self, scheduler): + def __init__(self, scheduler: Scheduler): self.scheduler = scheduler self.publishers = defaultdict(set) self.subscribers = defaultdict(set) @@ -35,7 +37,7 @@ def __init__(self, scheduler): self.scheduler.extensions["pubsub"] = self - def add_publisher(self, comm=None, name=None, worker=None): + async def add_publisher(self, comm=None, name=None, worker=None): logger.debug("Add publisher: %s %s", name, worker) self.publishers[name].add(worker) return { @@ -44,7 +46,7 @@ def add_publisher(self, comm=None, name=None, worker=None): and 
len(self.client_subscribers[name]) > 0, } - def add_subscriber(self, comm=None, name=None, worker=None, client=None): + async def add_subscriber(self, comm=None, name=None, worker=None, client=None): if worker: logger.debug("Add worker subscriber: %s %s", name, worker) self.subscribers[name].add(worker) @@ -62,7 +64,7 @@ def add_subscriber(self, comm=None, name=None, worker=None, client=None): ) self.client_subscribers[name].add(client) - def remove_publisher(self, comm=None, name=None, worker=None): + async def remove_publisher(self, comm=None, name=None, worker=None): if worker in self.publishers[name]: logger.debug("Remove publisher: %s %s", name, worker) self.publishers[name].remove(worker) @@ -71,7 +73,7 @@ def remove_publisher(self, comm=None, name=None, worker=None): del self.subscribers[name] del self.publishers[name] - def remove_subscriber(self, comm=None, name=None, worker=None, client=None): + async def remove_subscriber(self, comm=None, name=None, worker=None, client=None): if worker: logger.debug("Remove worker subscriber: %s %s", name, worker) self.subscribers[name].remove(worker) @@ -100,14 +102,14 @@ def remove_subscriber(self, comm=None, name=None, worker=None, client=None): del self.subscribers[name] del self.publishers[name] - def handle_message(self, name=None, msg=None, worker=None, client=None): + async def handle_message(self, name=None, msg=None, worker=None, client=None): for c in list(self.client_subscribers[name]): try: self.scheduler.client_comms[c].send( {"op": "pubsub-msg", "name": name, "msg": msg} ) except (KeyError, CommClosedError): - self.remove_subscriber(name=name, client=c) + await self.remove_subscriber(name=name, client=c) if client: for sub in self.subscribers[name]: @@ -119,7 +121,7 @@ def handle_message(self, name=None, msg=None, worker=None, client=None): class PubSubWorkerExtension: """ Extend Dask's Worker with routes to handle PubSub machinery """ - def __init__(self, worker): + def __init__(self, worker: Worker): self.worker = worker self.worker.stream_handlers.update( { @@ -136,15 +138,15 @@ def __init__(self, worker): self.worker.extensions["pubsub"] = self # circular reference - def add_subscriber(self, name=None, address=None, **info): + async def add_subscriber(self, name=None, address=None, **info): for pub in self.publishers[name]: pub.subscribers[address] = info - def remove_subscriber(self, name=None, address=None): + async def remove_subscriber(self, name=None, address=None): for pub in self.publishers[name]: del pub.subscribers[address] - def publish_scheduler(self, name=None, publish=None): + async def publish_scheduler(self, name=None, publish=None): self.publish_to_scheduler[name] = publish async def handle_message(self, name=None, msg=None): @@ -384,6 +386,8 @@ def __init__(self, name, worker=None, client=None): pubsub = self.worker.extensions["pubsub"] elif self.client: pubsub = self.client.extensions["pubsub"] + else: + raise ValueError("Must include a worker or client") self.loop.add_callback(pubsub.subscribers[name].add, self) msg = {"op": "pubsub-add-subscriber", "name": self.name} diff --git a/distributed/queues.py b/distributed/queues.py index 15d6c8adca5..7369606c74d 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -6,9 +6,9 @@ from dask.utils import stringify from .client import Future, Client -from .utils import sync, thread_state +from .scheduler import Scheduler +from .utils import sync, thread_state, parse_timedelta from .worker import get_client -from .utils import parse_timedelta logger = 
logging.getLogger(__name__) @@ -25,7 +25,7 @@ class QueueExtension: * queue_size """ - def __init__(self, scheduler): + def __init__(self, scheduler: Scheduler): self.scheduler = scheduler self.queues = dict() self.client_refcount = dict() @@ -54,7 +54,7 @@ def create(self, comm=None, name=None, client=None, maxsize=0): else: self.client_refcount[name] += 1 - def release(self, comm=None, name=None, client=None): + async def release(self, comm=None, name=None, client=None): if name not in self.queues: return @@ -78,7 +78,7 @@ async def put( record = {"type": "msgpack", "value": data} await asyncio.wait_for(self.queues[name].put(record), timeout=timeout) - def future_release(self, name=None, key=None, client=None): + async def future_release(self, name=None, key=None, client=None): self.future_refcount[name, key] -= 1 if self.future_refcount[name, key] == 0: self.scheduler.client_releases_keys(keys=[key], client="queue-%s" % name) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index bbd8f01aeda..e481513ba05 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -5,6 +5,7 @@ from contextlib import suppress from datetime import timedelta from functools import partial +from typing import Awaitable, Callable, Dict import inspect import itertools import json @@ -3039,7 +3040,7 @@ def __init__( self.event_counts = defaultdict(int) self.worker_plugins = [] - worker_handlers = { + worker_handlers: Dict[str, Callable[..., Awaitable]] = { "task-finished": self.handle_task_finished, "task-erred": self.handle_task_erred, "release": self.handle_release_data, @@ -3052,7 +3053,7 @@ def __init__( "log-event": self.log_worker_event, } - client_handlers = { + client_handlers: Dict[str, Callable[..., Awaitable]] = { "update-graph": self.update_graph, "update-graph-hlg": self.update_graph_hlg, "client-desires-keys": self.client_desires_keys, diff --git a/distributed/stealing.py b/distributed/stealing.py index 1e8428854ff..20121ed6dc9 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -10,6 +10,7 @@ from .core import CommClosedError from .diagnostics.plugin import SchedulerPlugin from .utils import log_errors, parse_timedelta +from .scheduler import Scheduler from tlz import topk @@ -22,7 +23,7 @@ class WorkStealing(SchedulerPlugin): - def __init__(self, scheduler): + def __init__(self, scheduler: Scheduler): self.scheduler = scheduler # { level: { task states } } self.stealable_all = [set() for i in range(15)] diff --git a/distributed/variable.py b/distributed/variable.py index db8da76e44c..2c7f78bd7b5 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -9,6 +9,7 @@ from dask.utils import stringify from .client import Future, Client from .utils import log_errors, TimeoutError, parse_timedelta +from .scheduler import Scheduler from .worker import get_client logger = logging.getLogger(__name__) @@ -24,7 +25,7 @@ class VariableExtension: * variable-delete """ - def __init__(self, scheduler): + def __init__(self, scheduler: Scheduler): self.scheduler = scheduler self.variables = dict() self.waiting = defaultdict(set) diff --git a/distributed/worker.py b/distributed/worker.py index 4cf60b597e9..c38b6333e77 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -15,6 +15,7 @@ import random import threading import sys +from typing import Awaitable, Callable, Dict import uuid import warnings import weakref @@ -676,7 +677,7 @@ def __init__( "plugin-add": self.plugin_add, } - stream_handlers = { + stream_handlers: Dict[str, Callable[..., 
Awaitable]] = { "close": self.close, "compute-task": self.add_task, "release-task": partial(self.release_key, report=False), @@ -1410,7 +1411,7 @@ def update_data(self, comm=None, data=None, report=True, serializers=None): info = {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"} return info - def delete_data(self, comm=None, keys=None, report=True): + async def delete_data(self, comm=None, keys=None, report=True): if keys: for key in list(keys): self.log.append((key, "delete")) @@ -1437,7 +1438,7 @@ async def set_resources(self, **resources): # Task Management # ################### - def add_task( + async def add_task( self, key, function=None, @@ -2264,7 +2265,7 @@ def update_who_has(self, who_has): pdb.set_trace() raise - def steal_request(self, key): + async def steal_request(self, key): # There may be a race condition between stealing and releasing a task. # In this case the self.tasks is already cleared. The `None` will be # registered as `already-computing` on the other end @@ -2290,7 +2291,7 @@ def steal_request(self, key): if self.validate: assert ts.runspec is None - def release_key(self, key, cause=None, reason=None, report=True): + async def release_key(self, key, cause=None, reason=None, report=True): try: if self.validate: assert isinstance(key, str) From d80cf8dd9d8a222e8bdf4dd424c1b99cebf4cd64 Mon Sep 17 00:00:00 2001 From: Ian Rose Date: Fri, 29 Jan 2021 16:34:35 -0800 Subject: [PATCH 3/6] Fix circular imports -- a lot of work just to get a type annotation. --- distributed/pubsub.py | 8 ++++++-- distributed/queues.py | 6 +++++- distributed/stealing.py | 6 +++++- distributed/variable.py | 6 +++++- 4 files changed, 21 insertions(+), 5 deletions(-) diff --git a/distributed/pubsub.py b/distributed/pubsub.py index e357d0173fd..0575b5a5d60 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -1,16 +1,20 @@ +from __future__ import annotations import asyncio from collections import defaultdict, deque from contextlib import suppress import logging import threading +import typing import weakref from .core import CommClosedError from .metrics import time from .utils import sync, TimeoutError, parse_timedelta from .protocol.serialize import to_serialize -from .scheduler import Scheduler -from .worker import Worker + +if typing.TYPE_CHECKING: + from .scheduler import Scheduler + from .worker import Worker logger = logging.getLogger(__name__) diff --git a/distributed/queues.py b/distributed/queues.py index 7369606c74d..d06d6c42e37 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -1,15 +1,19 @@ +from __future__ import annotations import asyncio from collections import defaultdict import logging +import typing import uuid from dask.utils import stringify from .client import Future, Client -from .scheduler import Scheduler from .utils import sync, thread_state, parse_timedelta from .worker import get_client +if typing.TYPE_CHECKING: + from .scheduler import Scheduler + logger = logging.getLogger(__name__) diff --git a/distributed/stealing.py b/distributed/stealing.py index 20121ed6dc9..a8132bb0d7b 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -1,7 +1,9 @@ +from __future__ import annotations from collections import defaultdict, deque import logging from math import log2 from time import time +import typing from tornado.ioloop import PeriodicCallback @@ -10,7 +12,9 @@ from .core import CommClosedError from .diagnostics.plugin import SchedulerPlugin from .utils import log_errors, parse_timedelta -from .scheduler import 
Scheduler + +if typing.TYPE_CHECKING: + from .scheduler import Scheduler from tlz import topk diff --git a/distributed/variable.py b/distributed/variable.py index 2c7f78bd7b5..86d183e6052 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -1,7 +1,9 @@ +from __future__ import annotations import asyncio from collections import defaultdict from contextlib import suppress import logging +import typing import uuid from tlz import merge @@ -9,9 +11,11 @@ from dask.utils import stringify from .client import Future, Client from .utils import log_errors, TimeoutError, parse_timedelta -from .scheduler import Scheduler from .worker import get_client +if typing.TYPE_CHECKING: + from .scheduler import Scheduler + logger = logging.getLogger(__name__) From 6b90ee81870fb37de8ece581f101f2f44ff2adff Mon Sep 17 00:00:00 2001 From: Ian Rose Date: Fri, 29 Jan 2021 17:24:14 -0800 Subject: [PATCH 4/6] Roll back some API changes, instead introduce asyncify util function. --- distributed/utils.py | 16 ++++++++++++++++ distributed/worker.py | 17 +++++++++-------- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/distributed/utils.py b/distributed/utils.py index 30044740f64..a0c6460b235 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -27,6 +27,7 @@ import pkgutil import base64 import tblib.pickling_support +from typing import Awaitable, Callable, TypeVar import xml.etree.ElementTree try: @@ -289,6 +290,21 @@ def quiet(): return results +T = TypeVar("T") + + +def asyncify(func: Callable[..., T]) -> Callable[..., Awaitable[T]]: + """ + Wrap a synchronous function in an async one + """ + + @functools.wraps(func) + async def wrapped(*args, **kwargs): + return func(*args, **kwargs) + + return wrapped + + def sync(loop, func, *args, callback_timeout=None, **kwargs): """ Run coroutine in loop running in separate thread. diff --git a/distributed/worker.py b/distributed/worker.py index c38b6333e77..c9712fe2962 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -48,6 +48,7 @@ from .sizeof import safe_sizeof as sizeof from .threadpoolexecutor import ThreadPoolExecutor, secede as tpe_secede from .utils import ( + asyncify, get_ip, typename, has_arg, @@ -679,10 +680,10 @@ def __init__( stream_handlers: Dict[str, Callable[..., Awaitable]] = { "close": self.close, - "compute-task": self.add_task, - "release-task": partial(self.release_key, report=False), - "delete-data": self.delete_data, - "steal-request": self.steal_request, + "compute-task": asyncify(self.add_task), + "release-task": asyncify(partial(self.release_key, report=False)), + "delete-data": asyncify(self.delete_data), + "steal-request": asyncify(self.steal_request), } super().__init__( @@ -1411,7 +1412,7 @@ def update_data(self, comm=None, data=None, report=True, serializers=None): info = {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"} return info - async def delete_data(self, comm=None, keys=None, report=True): + def delete_data(self, comm=None, keys=None, report=True): if keys: for key in list(keys): self.log.append((key, "delete")) @@ -1438,7 +1439,7 @@ async def set_resources(self, **resources): # Task Management # ################### - async def add_task( + def add_task( self, key, function=None, @@ -2265,7 +2266,7 @@ def update_who_has(self, who_has): pdb.set_trace() raise - async def steal_request(self, key): + def steal_request(self, key): # There may be a race condition between stealing and releasing a task. # In this case the self.tasks is already cleared. 
The `None` will be # registered as `already-computing` on the other end @@ -2291,7 +2292,7 @@ async def steal_request(self, key): if self.validate: assert ts.runspec is None - async def release_key(self, key, cause=None, reason=None, report=True): + def release_key(self, key, cause=None, reason=None, report=True): try: if self.validate: assert isinstance(key, str) From bebc06ea7a665d7db3d5d578b942367e0e97484d Mon Sep 17 00:00:00 2001 From: Ian Rose Date: Fri, 29 Jan 2021 17:43:12 -0800 Subject: [PATCH 5/6] Also ensure that every_cycle is async --- distributed/core.py | 12 +++++------- distributed/worker.py | 6 +++++- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index ff339fc872a..dcc2997d13b 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -7,7 +7,7 @@ import logging import threading import traceback -from typing import Callable, Dict, Awaitable +from typing import Awaitable, Callable, Dict, Iterable import uuid import weakref import warnings @@ -30,7 +30,6 @@ from . import profile from .system_monitor import SystemMonitor from .utils import ( - is_coroutine_function, get_traceback, truncate_exception, shutting_down, @@ -544,7 +543,9 @@ async def handle_comm(self, comm, shutting_down=shutting_down): "Failed while closing connection to %r: %s", address, e ) - async def handle_stream(self, comm, extra=None, every_cycle=[]): + async def handle_stream( + self, comm, extra=None, every_cycle: Iterable[Callable[..., Awaitable]] = [] + ): extra = extra or {} logger.info("Starting established connection") @@ -572,10 +573,7 @@ async def handle_stream(self, comm, extra=None, every_cycle=[]): await asyncio.sleep(0) for func in every_cycle: - if is_coroutine_function(func): - self.loop.add_callback(func) - else: - func() + self.loop.add_callback(func) except (CommClosedError, EnvironmentError) as e: pass diff --git a/distributed/worker.py b/distributed/worker.py index c9712fe2962..38fc5ba9509 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -987,7 +987,11 @@ async def heartbeat(self): async def handle_scheduler(self, comm): try: await self.handle_stream( - comm, every_cycle=[self.ensure_communicating, self.ensure_computing] + comm, + every_cycle=[ + asyncify(self.ensure_communicating), + self.ensure_computing, + ], ) except Exception as e: logger.exception(e) From 4f1944b9f354eb4496a20003e6493e867df16f57 Mon Sep 17 00:00:00 2001 From: Ian Rose Date: Fri, 29 Jan 2021 17:49:40 -0800 Subject: [PATCH 6/6] Python 3.6 support --- distributed/pubsub.py | 5 ++--- distributed/queues.py | 3 +-- distributed/stealing.py | 3 +-- distributed/variable.py | 3 +-- 4 files changed, 5 insertions(+), 9 deletions(-) diff --git a/distributed/pubsub.py b/distributed/pubsub.py index 0575b5a5d60..e1483d4a4d8 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -1,4 +1,3 @@ -from __future__ import annotations import asyncio from collections import defaultdict, deque from contextlib import suppress @@ -22,7 +21,7 @@ class PubSubSchedulerExtension: """ Extend Dask's scheduler with routes to handle PubSub machinery """ - def __init__(self, scheduler: Scheduler): + def __init__(self, scheduler: "Scheduler"): self.scheduler = scheduler self.publishers = defaultdict(set) self.subscribers = defaultdict(set) @@ -125,7 +124,7 @@ async def handle_message(self, name=None, msg=None, worker=None, client=None): class PubSubWorkerExtension: """ Extend Dask's Worker with routes to handle PubSub machinery """ - def __init__(self, worker: 
Worker): + def __init__(self, worker: "Worker"): self.worker = worker self.worker.stream_handlers.update( { diff --git a/distributed/queues.py b/distributed/queues.py index d06d6c42e37..987ffd9e344 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -1,4 +1,3 @@ -from __future__ import annotations import asyncio from collections import defaultdict import logging @@ -29,7 +28,7 @@ class QueueExtension: * queue_size """ - def __init__(self, scheduler: Scheduler): + def __init__(self, scheduler: "Scheduler"): self.scheduler = scheduler self.queues = dict() self.client_refcount = dict() diff --git a/distributed/stealing.py b/distributed/stealing.py index a8132bb0d7b..2e49115200c 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -1,4 +1,3 @@ -from __future__ import annotations from collections import defaultdict, deque import logging from math import log2 @@ -27,7 +26,7 @@ class WorkStealing(SchedulerPlugin): - def __init__(self, scheduler: Scheduler): + def __init__(self, scheduler: "Scheduler"): self.scheduler = scheduler # { level: { task states } } self.stealable_all = [set() for i in range(15)] diff --git a/distributed/variable.py b/distributed/variable.py index 86d183e6052..4d9f675ff98 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -1,4 +1,3 @@ -from __future__ import annotations import asyncio from collections import defaultdict from contextlib import suppress @@ -29,7 +28,7 @@ class VariableExtension: * variable-delete """ - def __init__(self, scheduler: Scheduler): + def __init__(self, scheduler: "Scheduler"): self.scheduler = scheduler self.variables = dict() self.waiting = defaultdict(set)
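
The net effect of the series is easiest to see in isolation. Below is a minimal, self-contained sketch of the pattern the final patches converge on: `asyncify` is copied from the helper added to `distributed/utils.py` in PATCH 4/6, while `sync_handler` and `async_handler` are hypothetical stand-ins for Worker methods such as `release_key` and `add_task` (they are not part of the patches). Once synchronous handlers are wrapped, every value in `stream_handlers` is a `Callable[..., Awaitable]`, so `handle_stream` can schedule all of them the same way instead of branching on `is_coroutine_function`.

```python
import asyncio
import functools
from typing import Awaitable, Callable, Dict, TypeVar

T = TypeVar("T")


def asyncify(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
    """Wrap a synchronous function in an async one (as added to distributed/utils.py)."""

    @functools.wraps(func)
    async def wrapped(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapped


# Hypothetical handlers standing in for Worker methods like release_key / add_task.
def sync_handler(key):
    print(f"released {key}")


async def async_handler(key):
    print(f"computed {key}")


# After this series every value in the table is awaitable, so handle_stream can
# hand each one to loop.add_callback without checking whether it is a coroutine.
stream_handlers: Dict[str, Callable[..., Awaitable]] = {
    "compute-task": async_handler,
    "release-task": asyncify(sync_handler),
}


async def main():
    for op, handler in stream_handlers.items():
        await handler(key="x")  # both entries now return coroutines


asyncio.run(main())
```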