Skip to content

Commit

Permalink
Server close faster (#6415)
Browse files: browse the repository at this point in the history
Co-authored-by: Matthew Rocklin <[email protected]>
  • Loading branch information
fjetter and mrocklin authored May 24, 2022
1 parent 7665eaa commit 60f0886
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 40 deletions.
55 changes: 32 additions & 23 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def __init__(
self.monitor = SystemMonitor()
self.counters = None
self.digests = None
self._ongoing_coroutines = weakref.WeakSet()
self._ongoing_coroutines = set()
self._event_finished = asyncio.Event()

self.listeners = []
Expand Down Expand Up @@ -508,7 +508,7 @@ async def handle_comm(self, comm):

await self
try:
while True:
while not self.__stopped:
try:
msg = await comm.read()
logger.debug("Message from %r: %s", address, msg)
Expand Down Expand Up @@ -579,10 +579,17 @@ async def handle_comm(self, comm):
result = handler(comm, **msg)
else:
result = handler(**msg)
if inspect.isawaitable(result):
result = asyncio.ensure_future(result)
if inspect.iscoroutine(result):
result = asyncio.create_task(
result, name=f"handle-comm-{address}-{op}"
)
self._ongoing_coroutines.add(result)
result.add_done_callback(self._ongoing_coroutines.remove)
result = await result
elif inspect.isawaitable(result):
raise RuntimeError(
f"Comm handler returned unknown awaitable. Expected coroutine, instead got {type(result)}"
)
except CommClosedError:
if self.status == Status.running:
logger.info("Lost connection to %r", address, exc_info=True)
Expand Down Expand Up @@ -666,34 +673,36 @@ async def handle_stream(self, comm, extra=None):
await comm.close()
assert comm.closed()

@gen.coroutine
def close(self):
async def close(self, timeout=None):
for pc in self.periodic_callbacks.values():
pc.stop()

if not self.__stopped:
self.__stopped = True
_stops = set()
for listener in self.listeners:
future = listener.stop()
if inspect.isawaitable(future):
yield future
for i in range(20):
# If there are still handlers running at this point, give them a
# second to finish gracefully themselves, otherwise...
if any(self._comms.values()):
yield asyncio.sleep(0.05)
else:
break
yield self.rpc.close()
yield [comm.close() for comm in list(self._comms)] # then forcefully close
for cb in self._ongoing_coroutines:
cb.cancel()
for i in range(10):
if all(c.cancelled() for c in self._ongoing_coroutines):
break
else:
yield asyncio.sleep(0.01)
_stops.add(future)
await asyncio.gather(*_stops)

def _ongoing_tasks():
return (
t for t in self._ongoing_coroutines if t is not asyncio.current_task()
)

# TODO: Deal with exceptions
try:
# Give the handlers a bit of time to finish gracefully
await asyncio.wait_for(
asyncio.gather(*_ongoing_tasks(), return_exceptions=True), 1
)
except asyncio.TimeoutError:
# the timeout on gather should've cancelled all the tasks
await asyncio.gather(*_ongoing_tasks(), return_exceptions=True)

await self.rpc.close()
await asyncio.gather(*[comm.close() for comm in list(self._comms)])
self._event_finished.set()


Expand Down
6 changes: 5 additions & 1 deletion distributed/deploy/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,11 @@ async def _close(self):

await self.scheduler.close()
for w in self._created:
assert w.status in {Status.closed, Status.failed}, w.status
assert w.status in {
Status.closing,
Status.closed,
Status.failed,
}, w.status

if hasattr(self, "_old_logging_level"):
silence_logging(self._old_logging_level)
Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6134,7 +6134,7 @@ async def test_shutdown():
await c.shutdown()

assert s.status == Status.closed
assert w.status == Status.closed
assert w.status in {Status.closed, Status.closing}


@gen_test()
Expand Down
31 changes: 16 additions & 15 deletions distributed/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,22 +978,22 @@ async def long_handler(comm):
await asyncio.sleep(0.01)


@pytest.mark.parametrize("close_via_rpc", [True, False])
@gen_test()
async def test_close_fast_without_active_handlers():
async def very_fast(comm):
return "done"
async def test_close_fast_without_active_handlers(close_via_rpc):

server = await Server({"do_stuff": very_fast})
server = await Server({})
server.handlers["terminate"] = server.close
await server.listen(0)
assert server._comms == {}

comm = await connect(server.address)
await comm.write({"op": "do_stuff"})
while not server._comms:
await asyncio.sleep(0.05)
fut = server.close()

await asyncio.wait_for(fut, 0.1)
if not close_via_rpc:
fut = server.close()
await asyncio.wait_for(fut, 0.5)
else:
async with rpc(server.address) as _rpc:
fut = _rpc.terminate(reply=False)
await asyncio.wait_for(fut, 0.5)


@gen_test()
Expand All @@ -1010,13 +1010,14 @@ async def long_handler(comm, delay=10):
await comm.write({"op": "wait"})
while not server._comms:
await asyncio.sleep(0.05)
fut = server.close()
task = asyncio.create_task(server.close())
wait_for_close = asyncio.Event()
task.add_done_callback(lambda _: wait_for_close.set)
# since the handler is running for a while, the close will not immediately
# go through. We'll give the comm about a second to close itself
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(fut, 0.5)
await comm.close()
await server.close()
await asyncio.wait_for(wait_for_close.wait(), 0.5)
await task


def test_expects_comm():
Expand Down

0 comments on commit 60f0886

Please sign in to comment.