diff --git a/hivemind/moe/server/connection_handler.py b/hivemind/moe/server/connection_handler.py index d00827689..50a79c37a 100644 --- a/hivemind/moe/server/connection_handler.py +++ b/hivemind/moe/server/connection_handler.py @@ -28,36 +28,52 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase): :param module_backends: a dict [UID -> ModuleBackend] with all active experts """ - def __init__(self, dht: DHT, module_backends: Dict[str, ModuleBackend]): + def __init__(self, dht: DHT, module_backends: Dict[str, ModuleBackend], *, shutdown_timeout: float = 3): super().__init__() self.dht, self.module_backends = dht, module_backends + self.shutdown_timeout = shutdown_timeout self._p2p: Optional[P2P] = None + self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=False) self.ready = MPFuture() def run(self): torch.set_num_threads(1) loop = switch_to_uvloop() + stop = asyncio.Event() + loop.add_reader(self._inner_pipe.fileno(), stop.set) async def _run(): try: self._p2p = await self.dht.replicate_p2p() await self.add_p2p_handlers(self._p2p, balanced=True) - - # wait forever - await asyncio.Future() - + self.ready.set_result(None) except Exception as e: + logger.error("ConnectionHandler failed to start:", exc_info=True) self.ready.set_exception(e) - return - self.ready.set_result(None) + try: + await stop.wait() + finally: + await self.remove_p2p_handlers(self._p2p) try: loop.run_until_complete(_run()) except KeyboardInterrupt: logger.debug("Caught KeyboardInterrupt, shutting down") + def shutdown(self): + if self.is_alive(): + self._outer_pipe.send("_shutdown") + self.join(self.shutdown_timeout) + if self.is_alive(): + logger.warning( + "ConnectionHandler did not shut down within the grace period; terminating it the hard way" + ) + self.terminate() + else: + logger.warning("ConnectionHandler shutdown had no effect, the process is already dead") + async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo: module_info = self.module_backends[request.uid].get_info() return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(module_info))