From 60f088683d41a45da63c7ac1d7331d6fd28f147e Mon Sep 17 00:00:00 2001
From: Florian Jetter
Date: Tue, 24 May 2022 11:28:54 +0200
Subject: [PATCH] Server close faster (#6415)

Co-authored-by: Matthew Rocklin
---
 distributed/core.py              | 55 +++++++++++++++++++-------------
 distributed/deploy/spec.py       |  6 +++-
 distributed/tests/test_client.py |  2 +-
 distributed/tests/test_core.py   | 31 +++++++++----------
 4 files changed, 54 insertions(+), 40 deletions(-)

diff --git a/distributed/core.py b/distributed/core.py
index d54592ef5a1..2a2a12521ff 100644
--- a/distributed/core.py
+++ b/distributed/core.py
@@ -187,7 +187,7 @@ def __init__(
         self.monitor = SystemMonitor()
         self.counters = None
         self.digests = None
-        self._ongoing_coroutines = weakref.WeakSet()
+        self._ongoing_coroutines = set()
         self._event_finished = asyncio.Event()
 
         self.listeners = []
@@ -508,7 +508,7 @@ async def handle_comm(self, comm):
             await self
 
         try:
-            while True:
+            while not self.__stopped:
                 try:
                     msg = await comm.read()
                     logger.debug("Message from %r: %s", address, msg)
@@ -579,10 +579,17 @@ async def handle_comm(self, comm):
                             result = handler(comm, **msg)
                         else:
                             result = handler(**msg)
-                        if inspect.isawaitable(result):
-                            result = asyncio.ensure_future(result)
+                        if inspect.iscoroutine(result):
+                            result = asyncio.create_task(
+                                result, name=f"handle-comm-{address}-{op}"
+                            )
                             self._ongoing_coroutines.add(result)
+                            result.add_done_callback(self._ongoing_coroutines.remove)
                             result = await result
+                        elif inspect.isawaitable(result):
+                            raise RuntimeError(
+                                f"Comm handler returned unknown awaitable. Expected coroutine, instead got {type(result)}"
+                            )
                     except CommClosedError:
                         if self.status == Status.running:
                             logger.info("Lost connection to %r", address, exc_info=True)
@@ -666,34 +673,36 @@ async def handle_stream(self, comm, extra=None):
             await comm.close()
             assert comm.closed()
 
-    @gen.coroutine
-    def close(self):
+    async def close(self, timeout=None):
         for pc in self.periodic_callbacks.values():
             pc.stop()
 
         if not self.__stopped:
             self.__stopped = True
+            _stops = set()
             for listener in self.listeners:
                 future = listener.stop()
                 if inspect.isawaitable(future):
-                    yield future
-        for i in range(20):
-            # If there are still handlers running at this point, give them a
-            # second to finish gracefully themselves, otherwise...
-            if any(self._comms.values()):
-                yield asyncio.sleep(0.05)
-            else:
-                break
-        yield self.rpc.close()
-        yield [comm.close() for comm in list(self._comms)]  # then forcefully close
-        for cb in self._ongoing_coroutines:
-            cb.cancel()
-        for i in range(10):
-            if all(c.cancelled() for c in self._ongoing_coroutines):
-                break
-            else:
-                yield asyncio.sleep(0.01)
+                    _stops.add(future)
+            await asyncio.gather(*_stops)
+
+        def _ongoing_tasks():
+            return (
+                t for t in self._ongoing_coroutines if t is not asyncio.current_task()
+            )
+
+        # TODO: Deal with exceptions
+        try:
+            # Give the handlers a bit of time to finish gracefully
+            await asyncio.wait_for(
+                asyncio.gather(*_ongoing_tasks(), return_exceptions=True), 1
+            )
+        except asyncio.TimeoutError:
+            # the timeout on gather should've cancelled all the tasks
+            await asyncio.gather(*_ongoing_tasks(), return_exceptions=True)
+
+        await self.rpc.close()
+        await asyncio.gather(*[comm.close() for comm in list(self._comms)])
 
         self._event_finished.set()
diff --git a/distributed/deploy/spec.py b/distributed/deploy/spec.py
index 778f0bf19be..f24ce7b05d0 100644
--- a/distributed/deploy/spec.py
+++ b/distributed/deploy/spec.py
@@ -418,7 +418,11 @@ async def _close(self):
             await self.scheduler.close()
 
         for w in self._created:
-            assert w.status in {Status.closed, Status.failed}, w.status
+            assert w.status in {
+                Status.closing,
+                Status.closed,
+                Status.failed,
+            }, w.status
 
         if hasattr(self, "_old_logging_level"):
             silence_logging(self._old_logging_level)
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index 8999534b34d..fe2fed711a5 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -6134,7 +6134,7 @@ async def test_shutdown():
             await c.shutdown()
 
             assert s.status == Status.closed
-            assert w.status == Status.closed
+            assert w.status in {Status.closed, Status.closing}
 
 
 @gen_test()
diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py
index 74e38f8671b..7422181d024 100644
--- a/distributed/tests/test_core.py
+++ b/distributed/tests/test_core.py
@@ -978,22 +978,22 @@ async def long_handler(comm):
         await asyncio.sleep(0.01)
 
 
+@pytest.mark.parametrize("close_via_rpc", [True, False])
 @gen_test()
-async def test_close_fast_without_active_handlers():
-    async def very_fast(comm):
-        return "done"
+async def test_close_fast_without_active_handlers(close_via_rpc):
 
-    server = await Server({"do_stuff": very_fast})
+    server = await Server({})
+    server.handlers["terminate"] = server.close
     await server.listen(0)
     assert server._comms == {}
 
-    comm = await connect(server.address)
-    await comm.write({"op": "do_stuff"})
-    while not server._comms:
-        await asyncio.sleep(0.05)
-    fut = server.close()
-
-    await asyncio.wait_for(fut, 0.1)
+    if not close_via_rpc:
+        fut = server.close()
+        await asyncio.wait_for(fut, 0.5)
+    else:
+        async with rpc(server.address) as _rpc:
+            fut = _rpc.terminate(reply=False)
+            await asyncio.wait_for(fut, 0.5)
 
 
 @gen_test()
@@ -1010,13 +1010,14 @@ async def long_handler(comm, delay=10):
     await comm.write({"op": "wait"})
     while not server._comms:
         await asyncio.sleep(0.05)
-    fut = server.close()
+    task = asyncio.create_task(server.close())
+    wait_for_close = asyncio.Event()
+    task.add_done_callback(lambda _: wait_for_close.set())
     # since the handler is running for a while, the close will not immediately
     # go through. We'll give the comm about a second to close itself
     with pytest.raises(asyncio.TimeoutError):
-        await asyncio.wait_for(fut, 0.5)
-    await comm.close()
-    await server.close()
+        await asyncio.wait_for(wait_for_close.wait(), 0.5)
+    await task
 
 
 def test_expects_comm():
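
The core of this patch is the grace-period shutdown in Server.close(): in-flight comm-handler tasks get roughly a second to finish, the asyncio.wait_for timeout cancels the gather (and with it the remaining tasks), and those tasks are awaited once more so cancellation settles before the comms and the rpc pool are torn down. The standalone sketch below illustrates that pattern outside of distributed; slow_handler, close_with_grace, and the grace value are illustrative names, not part of the patch.

import asyncio


async def slow_handler(delay: float) -> str:
    # Stands in for a long-running comm handler.
    await asyncio.sleep(delay)
    return "done"


async def close_with_grace(tasks, grace=1.0):
    # Mirrors the pattern used by the new Server.close(): wait up to `grace`
    # seconds for in-flight handler tasks; on timeout, wait_for has cancelled
    # the gather (and therefore the tasks), so await them once more to let
    # the cancellations finish before tearing anything else down.
    pending = [t for t in tasks if t is not asyncio.current_task()]
    try:
        await asyncio.wait_for(
            asyncio.gather(*pending, return_exceptions=True), grace
        )
    except asyncio.TimeoutError:
        await asyncio.gather(*pending, return_exceptions=True)


async def main():
    tasks = {asyncio.create_task(slow_handler(10), name="handle-comm-demo")}
    await close_with_grace(tasks, grace=0.1)
    # All handler tasks are done (here: cancelled) once close_with_grace returns.
    assert all(t.done() for t in tasks)


asyncio.run(main())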