Skip to content

Commit

Permalink
Add ConnectionHandler(..., start=True) interface
Browse files Browse the repository at this point in the history
  • Loading branch information
borzunov committed Aug 17, 2022
1 parent a0f92b5 commit f050665
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
18 changes: 17 additions & 1 deletion hivemind/moe/server/connection_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def __init__(
module_backends: Dict[str, ModuleBackend],
*,
balanced: bool = True,
shutdown_timeout: float = 3
shutdown_timeout: float = 3,
start: bool = False,
):
super().__init__()
self.dht, self.module_backends = dht, module_backends
Expand All @@ -44,6 +45,9 @@ def __init__(
self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=False)
self.ready = MPFuture()

if start:
self.run_in_background(await_ready=True)

def run(self):
torch.set_num_threads(1)
loop = switch_to_uvloop()
Expand All @@ -69,6 +73,18 @@ async def _run():
except KeyboardInterrupt:
logger.debug("Caught KeyboardInterrupt, shutting down")

def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
"""
Starts ConnectionHandler in a background process. If :await_ready:, this method will wait until
it is ready to process incoming requests or for :timeout: seconds max.
"""
self.start()
if await_ready:
self.wait_until_ready(timeout)

def wait_until_ready(self, timeout: Optional[float] = None) -> None:
self.ready.result(timeout=timeout)

def shutdown(self):
if self.is_alive():
self._outer_pipe.send("_shutdown")
Expand Down
10 changes: 3 additions & 7 deletions tests/test_connection_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@
async def client_stub():
handler_dht = DHT(start=True)
module_backends = {"expert1": DummyModuleBackend("expert1", k=1), "expert2": DummyModuleBackend("expert2", k=2)}
handler = ConnectionHandler(handler_dht, module_backends)
handler.start()
assert handler.ready.exception() is None
handler = ConnectionHandler(handler_dht, module_backends, start=True)

client_dht = DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs())
client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id)
Expand Down Expand Up @@ -164,10 +162,8 @@ async def test_connection_handler_shutdown():
module_backends = {"expert1": DummyModuleBackend("expert1", k=1), "expert2": DummyModuleBackend("expert2", k=2)}

for _ in range(3):
handler = ConnectionHandler(handler_dht, module_backends, balanced=False)
handler.start()
# handler.ready would contain an exception if the previous handlers were not removed from hivemind.P2P
assert handler.ready.exception() is None
handler = ConnectionHandler(handler_dht, module_backends, balanced=False, start=True)
# The line above would raise an exception if the previous handlers were not removed from hivemind.P2P
handler.shutdown()

handler_dht.shutdown()
Expand Down

0 comments on commit f050665

Please sign in to comment.