From 0e2be4a2b13e0f8d98ebc4ae31632f0e3ca82c47 Mon Sep 17 00:00:00 2001 From: Florian Jetter Date: Mon, 15 Apr 2024 11:56:35 +0200 Subject: [PATCH] Fix race condition for published futures with annotations (#8577) --- distributed/client.py | 11 ++++++++--- distributed/publish.py | 14 ++++++++++++++ distributed/tests/test_publish.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 3 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 36617dc0c34..df976d0e058 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2650,18 +2650,23 @@ def retry(self, futures, asynchronous=None): @log_errors async def _publish_dataset(self, *args, name=None, override=False, **kwargs): coroutines = [] + uid = uuid.uuid4().hex + self._send_to_scheduler({"op": "publish_flush_batched_send", "uid": uid}) def add_coro(name, data): keys = [f.key for f in futures_of(data)] - coroutines.append( - self.scheduler.publish_put( + + async def _(): + await self.scheduler.publish_wait_flush(uid=uid) + await self.scheduler.publish_put( keys=keys, name=name, data=to_serialize(data), override=override, client=self.id, ) - ) + + coroutines.append(_()) if name: if len(args) == 0: diff --git a/distributed/publish.py b/distributed/publish.py index 3d97e9d1f00..e7887b8dc96 100644 --- a/distributed/publish.py +++ b/distributed/publish.py @@ -1,5 +1,7 @@ from __future__ import annotations +import asyncio +from collections import defaultdict from collections.abc import MutableMapping from dask.utils import stringify @@ -25,9 +27,21 @@ def __init__(self, scheduler): "publish_put": self.put, "publish_get": self.get, "publish_delete": self.delete, + "publish_wait_flush": self.flush_wait, + } + stream_handlers = { + "publish_flush_batched_send": self.flush_receive, } self.scheduler.handlers.update(handlers) + self.scheduler.stream_handlers.update(stream_handlers) + self._flush_received = defaultdict(asyncio.Event) + + def flush_receive(self, uid, **kwargs): + self._flush_received[uid].set() + + async def flush_wait(self, uid): + await self._flush_received[uid].wait() @log_errors def put(self, keys=None, data=None, name=None, override=False, client=None): diff --git a/distributed/tests/test_publish.py b/distributed/tests/test_publish.py index f1107b57619..3669030ccaf 100644 --- a/distributed/tests/test_publish.py +++ b/distributed/tests/test_publish.py @@ -11,6 +11,7 @@ from distributed.metrics import time from distributed.protocol import Serialized from distributed.utils_test import gen_cluster, inc +from distributed.worker import get_worker @gen_cluster() @@ -301,3 +302,30 @@ async def test_deserialize_client(c, s, a, b): from distributed.client import _current_client assert _current_client.get() is c + + +@gen_cluster(client=True, worker_kwargs={"resources": {"A": 1}}) +async def test_publish_submit_ordering(c, s, a, b): + RESOURCES = {"A": 1} + + def _retrieve_annotations(): + worker = get_worker() + task = worker.state.tasks.get(worker.get_current_task()) + return task.annotations + + # If publish does not take the same comm channel as the submit, it can + # happen that the publish message reaches the scheduler before the submit + # such that the state of the published future is not the one that has been + # requested from the submit. Particularly, this lets us drop annotations + # The current implementation does in fact not use the same channel due to + # serialization issue (including Futures in BatchedSend appends them to the + # "recent messages" log which screws with the refcounting) but ensure that + # all queued up messages are flushed and received by the schduler befure + # publishing + future = c.submit(_retrieve_annotations, resources=RESOURCES, pure=False) + + await c.publish_dataset(future, name="foo") + assert await c.list_datasets() == ("foo",) + + result = await future.result() + assert result == {"resources": RESOURCES}