diff --git a/hivemind/moe/server/connection_handler.py b/hivemind/moe/server/connection_handler.py index 2b21be73d..f6f0bcc85 100644 --- a/hivemind/moe/server/connection_handler.py +++ b/hivemind/moe/server/connection_handler.py @@ -34,7 +34,8 @@ def __init__( module_backends: Dict[str, ModuleBackend], *, balanced: bool = True, - shutdown_timeout: float = 3 + shutdown_timeout: float = 3, + start: bool = False, ): super().__init__() self.dht, self.module_backends = dht, module_backends @@ -44,6 +45,9 @@ def __init__( self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=False) self.ready = MPFuture() + if start: + self.run_in_background(await_ready=True) + def run(self): torch.set_num_threads(1) loop = switch_to_uvloop() @@ -69,6 +73,18 @@ async def _run(): except KeyboardInterrupt: logger.debug("Caught KeyboardInterrupt, shutting down") + def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None: + """ + Starts ConnectionHandler in a background process. If :await_ready:, this method will wait until + it is ready to process incoming requests or for :timeout: seconds max. + """ + self.start() + if await_ready: + self.wait_until_ready(timeout) + + def wait_until_ready(self, timeout: Optional[float] = None) -> None: + self.ready.result(timeout=timeout) + def shutdown(self): if self.is_alive(): self._outer_pipe.send("_shutdown") diff --git a/tests/test_connection_handler.py b/tests/test_connection_handler.py index 9a517f242..0f3220574 100644 --- a/tests/test_connection_handler.py +++ b/tests/test_connection_handler.py @@ -24,9 +24,7 @@ async def client_stub(): handler_dht = DHT(start=True) module_backends = {"expert1": DummyModuleBackend("expert1", k=1), "expert2": DummyModuleBackend("expert2", k=2)} - handler = ConnectionHandler(handler_dht, module_backends) - handler.start() - assert handler.ready.exception() is None + handler = ConnectionHandler(handler_dht, module_backends, start=True) client_dht = DHT(start=True, client_mode=True, initial_peers=handler.dht.get_visible_maddrs()) client_stub = ConnectionHandler.get_stub(await client_dht.replicate_p2p(), handler.dht.peer_id) @@ -164,10 +162,8 @@ async def test_connection_handler_shutdown(): module_backends = {"expert1": DummyModuleBackend("expert1", k=1), "expert2": DummyModuleBackend("expert2", k=2)} for _ in range(3): - handler = ConnectionHandler(handler_dht, module_backends, balanced=False) - handler.start() - # handler.ready would contain an exception if the previous handlers were not removed from hivemind.P2P - assert handler.ready.exception() is None + handler = ConnectionHandler(handler_dht, module_backends, balanced=False, start=True) + # The line above would raise an exception if the previous handlers were not removed from hivemind.P2P handler.shutdown() handler_dht.shutdown()