diff --git a/tests/test_lifespan.py b/tests/test_lifespan.py index de584e1a8..2c83199bd 100644 --- a/tests/test_lifespan.py +++ b/tests/test_lifespan.py @@ -197,3 +197,33 @@ async def test(): loop = asyncio.new_event_loop() loop.run_until_complete(test()) + + +@pytest.mark.parametrize("mode", ("auto", "on")) +@pytest.mark.parametrize("raise_exception", (True, False)) +def test_lifespan_with_failed_shutdown(mode, raise_exception): + async def app(scope, receive, send): + message = await receive() + assert message["type"] == "lifespan.startup" + await send({"type": "lifespan.startup.complete"}) + message = await receive() + assert message["type"] == "lifespan.shutdown" + await send({"type": "lifespan.shutdown.failed"}) + + if raise_exception: + # App should be able to re-raise an exception if startup failed. + raise RuntimeError() + + async def test(): + config = Config(app=app, lifespan=mode) + lifespan = LifespanOn(config) + + await lifespan.startup() + assert not lifespan.startup_failed + await lifespan.shutdown() + assert lifespan.shutdown_failed + assert lifespan.error_occured is raise_exception + assert lifespan.should_exit + + loop = asyncio.new_event_loop() + loop.run_until_complete(test()) diff --git a/uvicorn/lifespan/on.py b/uvicorn/lifespan/on.py index 0f0644fef..a8deb7644 100644 --- a/uvicorn/lifespan/on.py +++ b/uvicorn/lifespan/on.py @@ -20,6 +20,7 @@ def __init__(self, config: Config) -> None: self.receive_queue: "Queue[LifespanReceiveMessage]" = asyncio.Queue() self.error_occured = False self.startup_failed = False + self.shutdown_failed = False self.should_exit = False async def startup(self) -> None: @@ -43,7 +44,14 @@ async def shutdown(self) -> None: self.logger.info("Waiting for application shutdown.") await self.receive_queue.put({"type": "lifespan.shutdown"}) await self.shutdown_event.wait() - self.logger.info("Application shutdown complete.") + + if self.shutdown_failed or ( + self.error_occured and self.config.lifespan == "on" + ): + self.logger.error("Application shutdown failed. Exiting.") + self.should_exit = True + else: + self.logger.info("Application shutdown complete.") async def main(self) -> None: try: @@ -56,7 +64,7 @@ async def main(self) -> None: except BaseException as exc: self.asgi = None self.error_occured = True - if self.startup_failed: + if self.startup_failed or self.shutdown_failed: return if self.config.lifespan == "auto": msg = "ASGI 'lifespan' protocol appears unsupported." @@ -73,6 +81,7 @@ async def send(self, message: LifespanSendMessage) -> None: "lifespan.startup.complete", "lifespan.startup.failed", "lifespan.shutdown.complete", + "lifespan.shutdown.failed", ) if message["type"] == "lifespan.startup.complete": @@ -93,5 +102,13 @@ async def send(self, message: LifespanSendMessage) -> None: assert not self.shutdown_event.is_set(), STATE_TRANSITION_ERROR self.shutdown_event.set() + elif message["type"] == "lifespan.shutdown.failed": + assert self.startup_event.is_set(), STATE_TRANSITION_ERROR + assert not self.shutdown_event.is_set(), STATE_TRANSITION_ERROR + self.shutdown_event.set() + self.shutdown_failed = True + if message.get("message"): + self.logger.error(message["message"]) + async def receive(self) -> LifespanReceiveMessage: return await self.receive_queue.get()