From 096d50617f423432ce2e2a1bf4934500d97d2c4d Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 23 Oct 2024 12:00:01 -1000 Subject: [PATCH] Fix cancellation leaking upward from the timeout util (#129003) --- homeassistant/util/timeout.py | 33 +++++++++- tests/util/test_timeout.py | 114 +++++++++++++++++++++++++++++++++- 2 files changed, 143 insertions(+), 4 deletions(-) diff --git a/homeassistant/util/timeout.py b/homeassistant/util/timeout.py index 821f502694bb57..ddabdf2746d7f2 100644 --- a/homeassistant/util/timeout.py +++ b/homeassistant/util/timeout.py @@ -16,7 +16,7 @@ ZONE_GLOBAL = "global" -class _State(str, enum.Enum): +class _State(enum.Enum): """States of a task.""" INIT = "INIT" @@ -160,11 +160,16 @@ def __init__( self._wait_zone: asyncio.Event = asyncio.Event() self._state: _State = _State.INIT self._cool_down: float = cool_down + self._cancelling = 0 async def __aenter__(self) -> Self: self._manager.global_tasks.append(self) self._start_timer() self._state = _State.ACTIVE + # Remember if the task was already cancelling + # so when we __aexit__ we can decide if we should + # raise asyncio.TimeoutError or let the cancellation propagate + self._cancelling = self._task.cancelling() return self async def __aexit__( @@ -177,7 +182,15 @@ async def __aexit__( self._manager.global_tasks.remove(self) # Timeout on exit - if exc_type is asyncio.CancelledError and self.state == _State.TIMEOUT: + if exc_type is asyncio.CancelledError and self.state is _State.TIMEOUT: + # The timeout was hit, and the task was cancelled + # so we need to uncancel the task since the cancellation + # should not leak out of the context manager + if self._task.uncancel() > self._cancelling: + # If the task was already cancelling don't raise + # asyncio.TimeoutError and instead return None + # to allow the cancellation to propagate + return None raise TimeoutError self._state = _State.EXIT @@ -266,6 +279,7 @@ def __init__( self._time_left: float = timeout self._expiration_time: float | None = None self._timeout_handler: asyncio.Handle | None = None + self._cancelling = 0 @property def state(self) -> _State: @@ -280,6 +294,11 @@ async def __aenter__(self) -> Self: if self._zone.freezes_done: self._start_timer() + # Remember if the task was already cancelling + # so when we __aexit__ we can decide if we should + # raise asyncio.TimeoutError or let the cancellation propagate + self._cancelling = self._task.cancelling() + return self async def __aexit__( @@ -292,7 +311,15 @@ async def __aexit__( self._stop_timer() # Timeout on exit - if exc_type is asyncio.CancelledError and self.state == _State.TIMEOUT: + if exc_type is asyncio.CancelledError and self.state is _State.TIMEOUT: + # The timeout was hit, and the task was cancelled + # so we need to uncancel the task since the cancellation + # should not leak out of the context manager + if self._task.uncancel() > self._cancelling: + # If the task was already cancelling don't raise + # asyncio.TimeoutError and instead return None + # to allow the cancellation to propagate + return None raise TimeoutError self._state = _State.EXIT diff --git a/tests/util/test_timeout.py b/tests/util/test_timeout.py index 1c4b06d99b4533..5e8261c4c02b79 100644 --- a/tests/util/test_timeout.py +++ b/tests/util/test_timeout.py @@ -146,6 +146,62 @@ async def test_simple_global_timeout_freeze_with_executor_job( await hass.async_add_executor_job(time.sleep, 0.3) +async def test_simple_global_timeout_does_not_leak_upward( + hass: HomeAssistant, +) -> None: + """Test a global timeout does not leak upward.""" + timeout = TimeoutManager() + current_task = asyncio.current_task() + assert current_task is not None + cancelling_inside_timeout = None + + with pytest.raises(asyncio.TimeoutError): # noqa: PT012 + async with timeout.async_timeout(0.1): + cancelling_inside_timeout = current_task.cancelling() + await asyncio.sleep(0.3) + + assert cancelling_inside_timeout == 0 + # After the context manager exits, the task should no longer be cancelling + assert current_task.cancelling() == 0 + + +async def test_simple_global_timeout_does_swallow_cancellation( + hass: HomeAssistant, +) -> None: + """Test a global timeout does not swallow cancellation.""" + timeout = TimeoutManager() + current_task = asyncio.current_task() + assert current_task is not None + cancelling_inside_timeout = None + + async def task_with_timeout() -> None: + nonlocal cancelling_inside_timeout + new_task = asyncio.current_task() + assert new_task is not None + with pytest.raises(asyncio.TimeoutError): # noqa: PT012 + cancelling_inside_timeout = new_task.cancelling() + async with timeout.async_timeout(0.1): + await asyncio.sleep(0.3) + + # After the context manager exits, the task should no longer be cancelling + assert current_task.cancelling() == 0 + + task = asyncio.create_task(task_with_timeout()) + await asyncio.sleep(0) + task.cancel() + assert task.cancelling() == 1 + + assert cancelling_inside_timeout == 0 + # Cancellation should not leak into the current task + assert current_task.cancelling() == 0 + # Cancellation should not be swallowed if the task is cancelled + # and it also times out + await asyncio.sleep(0) + with pytest.raises(asyncio.CancelledError): + await task + assert task.cancelling() == 1 + + async def test_simple_global_timeout_freeze_reset() -> None: """Test a simple global timeout freeze reset.""" timeout = TimeoutManager() @@ -166,6 +222,62 @@ async def test_simple_zone_timeout() -> None: await asyncio.sleep(0.3) +async def test_simple_zone_timeout_does_not_leak_upward( + hass: HomeAssistant, +) -> None: + """Test a zone timeout does not leak upward.""" + timeout = TimeoutManager() + current_task = asyncio.current_task() + assert current_task is not None + cancelling_inside_timeout = None + + with pytest.raises(asyncio.TimeoutError): # noqa: PT012 + async with timeout.async_timeout(0.1, "test"): + cancelling_inside_timeout = current_task.cancelling() + await asyncio.sleep(0.3) + + assert cancelling_inside_timeout == 0 + # After the context manager exits, the task should no longer be cancelling + assert current_task.cancelling() == 0 + + +async def test_simple_zone_timeout_does_swallow_cancellation( + hass: HomeAssistant, +) -> None: + """Test a zone timeout does not swallow cancellation.""" + timeout = TimeoutManager() + current_task = asyncio.current_task() + assert current_task is not None + cancelling_inside_timeout = None + + async def task_with_timeout() -> None: + nonlocal cancelling_inside_timeout + new_task = asyncio.current_task() + assert new_task is not None + with pytest.raises(asyncio.TimeoutError): # noqa: PT012 + async with timeout.async_timeout(0.1, "test"): + cancelling_inside_timeout = current_task.cancelling() + await asyncio.sleep(0.3) + + # After the context manager exits, the task should no longer be cancelling + assert current_task.cancelling() == 0 + + task = asyncio.create_task(task_with_timeout()) + await asyncio.sleep(0) + task.cancel() + assert task.cancelling() == 1 + + # Cancellation should not leak into the current task + assert cancelling_inside_timeout == 0 + assert current_task.cancelling() == 0 + # Cancellation should not be swallowed if the task is cancelled + # and it also times out + await asyncio.sleep(0) + with pytest.raises(asyncio.CancelledError): + await task + assert task.cancelling() == 1 + + async def test_multiple_zone_timeout() -> None: """Test a simple zone timeout.""" timeout = TimeoutManager() @@ -327,7 +439,7 @@ async def test_simple_zone_timeout_freeze_without_timeout_exeption() -> None: await asyncio.sleep(0.4) -async def test_simple_zone_timeout_zone_with_timeout_exeption() -> None: +async def test_simple_zone_timeout_zone_with_timeout_exception() -> None: """Test a simple zone timeout freeze on a zone that does not have a timeout set.""" timeout = TimeoutManager()