diff --git a/distributed/comm/core.py b/distributed/comm/core.py index 7d16bda2336..74246a22d50 100644 --- a/distributed/comm/core.py +++ b/distributed/comm/core.py @@ -280,7 +280,7 @@ def time_left(): backoff_base = 0.01 attempt = 0 - + logger.debug("Establishing connection to %s", loc) # Prefer multiple small attempts than one long attempt. This should protect # primarily from DNS race conditions # gh3104, gh4176, gh4167 diff --git a/distributed/core.py b/distributed/core.py index 96dfe7b68a9..2734b9a0e37 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -443,6 +443,47 @@ def status(self, value: Status) -> None: raise TypeError(f"Expected Status; got {value!r}") self._status = value + @property + def incoming_comms_open(self) -> int: + """The number of total incoming connections listening to remote RPCs""" + return len(self._comms) + + @property + def incoming_comms_active(self) -> int: + """The number of connections currently handling a remote RPC""" + return len([c for c, op in self._comms.items() if op is not None]) + + @property + def outgoing_comms_open(self) -> int: + """The number of connections currently open and waiting for a remote RPC""" + return self.rpc.open + + @property + def outgoing_comms_active(self) -> int: + """The number of outgoing connections that are currently used to + execute a RPC""" + return self.rpc.active + + def get_connection_counters(self) -> dict[str, int]: + """A dict with various connection counters + + See also + -------- + Server.incoming_comms_open + Server.incoming_comms_active + Server.outgoing_comms_open + Server.outgoing_comms_active + """ + return { + attr: getattr(self, attr) + for attr in [ + "incoming_comms_open", + "incoming_comms_active", + "outgoing_comms_open", + "outgoing_comms_active", + ] + } + async def finished(self): """Wait until the server has finished""" await self._event_finished.wait() diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py index 35e8e748ed7..739cf557a64 100644 --- a/distributed/tests/test_core.py +++ b/distributed/tests/test_core.py @@ -1134,11 +1134,71 @@ async def long_handler(comm): while not server._comms: await asyncio.sleep(0.05) assert set(server._comms.values()) == {"wait"} + + assert server.incoming_comms_open == 1 + assert server.incoming_comms_active == 1 + + def validate_dict(server): + assert ( + server.get_connection_counters()["incoming_comms_open"] + == server.incoming_comms_open + ) + assert ( + server.get_connection_counters()["incoming_comms_active"] + == server.incoming_comms_active + ) + + assert ( + server.get_connection_counters()["outgoing_comms_open"] + == server.rpc.open + ) + assert ( + server.get_connection_counters()["outgoing_comms_active"] + == server.rpc.active + ) + + validate_dict(server) + assert await comm.read() == "done" assert set(server._comms.values()) == {None} + assert server.incoming_comms_open == 1 + assert server.incoming_comms_active == 0 + validate_dict(server) await comm.close() + while server._comms: await asyncio.sleep(0.01) + assert server.incoming_comms_active == 0 + assert server.incoming_comms_open == 0 + validate_dict(server) + + async with Server({}) as server2: + rpc_ = server2.rpc(server.address) + task = asyncio.create_task(rpc_.wait()) + while not server.incoming_comms_active: + await asyncio.sleep(0.1) + assert server.incoming_comms_active == 1 + assert server.incoming_comms_open == 1 + assert server.outgoing_comms_active == 0 + assert server.outgoing_comms_open == 0 + + assert server2.incoming_comms_active == 0 + assert server2.incoming_comms_open == 0 + assert server2.outgoing_comms_active == 1 + assert server2.outgoing_comms_open == 1 + validate_dict(server) + + await task + assert server.incoming_comms_active == 0 + assert server.incoming_comms_open == 1 + assert server.outgoing_comms_active == 0 + assert server.outgoing_comms_open == 0 + + assert server2.incoming_comms_active == 0 + assert server2.incoming_comms_open == 0 + assert server2.outgoing_comms_active == 0 + assert server2.outgoing_comms_open == 1 + validate_dict(server) @pytest.mark.parametrize("close_via_rpc", [True, False])