Skip to content

Commit

Permalink
Implement ConnectionHandler.shutdown()
Browse files Browse the repository at this point in the history
  • Loading branch information
borzunov committed Aug 17, 2022
1 parent bee1e55 commit 96c63f0
Showing 1 changed file with 23 additions and 7 deletions.
30 changes: 23 additions & 7 deletions hivemind/moe/server/connection_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,36 +28,52 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
:param module_backends: a dict [UID -> ModuleBackend] with all active experts
"""

def __init__(self, dht: DHT, module_backends: Dict[str, ModuleBackend]):
def __init__(self, dht: DHT, module_backends: Dict[str, ModuleBackend], *, shutdown_timeout: float = 3):
super().__init__()
self.dht, self.module_backends = dht, module_backends
self.shutdown_timeout = shutdown_timeout
self._p2p: Optional[P2P] = None

self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=False)
self.ready = MPFuture()

def run(self):
torch.set_num_threads(1)
loop = switch_to_uvloop()
stop = asyncio.Event()
loop.add_reader(self._inner_pipe.fileno(), stop.set)

async def _run():
try:
self._p2p = await self.dht.replicate_p2p()
await self.add_p2p_handlers(self._p2p, balanced=True)

# wait forever
await asyncio.Future()

self.ready.set_result(None)
except Exception as e:
logger.error("ConnectionHandler failed to start:", exc_info=True)
self.ready.set_exception(e)
return

self.ready.set_result(None)
try:
await stop.wait()
finally:
await self.remove_p2p_handlers(self._p2p)

try:
loop.run_until_complete(_run())
except KeyboardInterrupt:
logger.debug("Caught KeyboardInterrupt, shutting down")

def shutdown(self):
if self.is_alive():
self._outer_pipe.send("_shutdown")
self.join(self.shutdown_timeout)
if self.is_alive():
logger.warning(
"ConnectionHandler did not shut down within the grace period; terminating it the hard way"
)
self.terminate()
else:
logger.warning("ConnectionHandler shutdown had no effect, the process is already dead")

async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
module_info = self.module_backends[request.uid].get_info()
return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(module_info))
Expand Down

0 comments on commit 96c63f0

Please sign in to comment.