Skip to content

Commit

Permalink
Fix cancellation leaking upward from the timeout util (#129003)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored and frenck committed Oct 25, 2024
1 parent 9dd8c0c commit 096d506
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 4 deletions.
33 changes: 30 additions & 3 deletions homeassistant/util/timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
ZONE_GLOBAL = "global"


class _State(str, enum.Enum):
class _State(enum.Enum):
"""States of a task."""

INIT = "INIT"
Expand Down Expand Up @@ -160,11 +160,16 @@ def __init__(
self._wait_zone: asyncio.Event = asyncio.Event()
self._state: _State = _State.INIT
self._cool_down: float = cool_down
self._cancelling = 0

async def __aenter__(self) -> Self:
self._manager.global_tasks.append(self)
self._start_timer()
self._state = _State.ACTIVE
# Remember if the task was already cancelling
# so when we __aexit__ we can decide if we should
# raise asyncio.TimeoutError or let the cancellation propagate
self._cancelling = self._task.cancelling()
return self

async def __aexit__(
Expand All @@ -177,7 +182,15 @@ async def __aexit__(
self._manager.global_tasks.remove(self)

# Timeout on exit
if exc_type is asyncio.CancelledError and self.state == _State.TIMEOUT:
if exc_type is asyncio.CancelledError and self.state is _State.TIMEOUT:
# The timeout was hit, and the task was cancelled
# so we need to uncancel the task since the cancellation
# should not leak out of the context manager
if self._task.uncancel() > self._cancelling:
# If the task was already cancelling don't raise
# asyncio.TimeoutError and instead return None
# to allow the cancellation to propagate
return None
raise TimeoutError

self._state = _State.EXIT
Expand Down Expand Up @@ -266,6 +279,7 @@ def __init__(
self._time_left: float = timeout
self._expiration_time: float | None = None
self._timeout_handler: asyncio.Handle | None = None
self._cancelling = 0

@property
def state(self) -> _State:
Expand All @@ -280,6 +294,11 @@ async def __aenter__(self) -> Self:
if self._zone.freezes_done:
self._start_timer()

# Remember if the task was already cancelling
# so when we __aexit__ we can decide if we should
# raise asyncio.TimeoutError or let the cancellation propagate
self._cancelling = self._task.cancelling()

return self

async def __aexit__(
Expand All @@ -292,7 +311,15 @@ async def __aexit__(
self._stop_timer()

# Timeout on exit
if exc_type is asyncio.CancelledError and self.state == _State.TIMEOUT:
if exc_type is asyncio.CancelledError and self.state is _State.TIMEOUT:
# The timeout was hit, and the task was cancelled
# so we need to uncancel the task since the cancellation
# should not leak out of the context manager
if self._task.uncancel() > self._cancelling:
# If the task was already cancelling don't raise
# asyncio.TimeoutError and instead return None
# to allow the cancellation to propagate
return None
raise TimeoutError

self._state = _State.EXIT
Expand Down
114 changes: 113 additions & 1 deletion tests/util/test_timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,62 @@ async def test_simple_global_timeout_freeze_with_executor_job(
await hass.async_add_executor_job(time.sleep, 0.3)


async def test_simple_global_timeout_does_not_leak_upward(
hass: HomeAssistant,
) -> None:
"""Test a global timeout does not leak upward."""
timeout = TimeoutManager()
current_task = asyncio.current_task()
assert current_task is not None
cancelling_inside_timeout = None

with pytest.raises(asyncio.TimeoutError): # noqa: PT012
async with timeout.async_timeout(0.1):
cancelling_inside_timeout = current_task.cancelling()
await asyncio.sleep(0.3)

assert cancelling_inside_timeout == 0
# After the context manager exits, the task should no longer be cancelling
assert current_task.cancelling() == 0


async def test_simple_global_timeout_does_swallow_cancellation(
hass: HomeAssistant,
) -> None:
"""Test a global timeout does not swallow cancellation."""
timeout = TimeoutManager()
current_task = asyncio.current_task()
assert current_task is not None
cancelling_inside_timeout = None

async def task_with_timeout() -> None:
nonlocal cancelling_inside_timeout
new_task = asyncio.current_task()
assert new_task is not None
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
cancelling_inside_timeout = new_task.cancelling()
async with timeout.async_timeout(0.1):
await asyncio.sleep(0.3)

# After the context manager exits, the task should no longer be cancelling
assert current_task.cancelling() == 0

task = asyncio.create_task(task_with_timeout())
await asyncio.sleep(0)
task.cancel()
assert task.cancelling() == 1

assert cancelling_inside_timeout == 0
# Cancellation should not leak into the current task
assert current_task.cancelling() == 0
# Cancellation should not be swallowed if the task is cancelled
# and it also times out
await asyncio.sleep(0)
with pytest.raises(asyncio.CancelledError):
await task
assert task.cancelling() == 1


async def test_simple_global_timeout_freeze_reset() -> None:
"""Test a simple global timeout freeze reset."""
timeout = TimeoutManager()
Expand All @@ -166,6 +222,62 @@ async def test_simple_zone_timeout() -> None:
await asyncio.sleep(0.3)


async def test_simple_zone_timeout_does_not_leak_upward(
hass: HomeAssistant,
) -> None:
"""Test a zone timeout does not leak upward."""
timeout = TimeoutManager()
current_task = asyncio.current_task()
assert current_task is not None
cancelling_inside_timeout = None

with pytest.raises(asyncio.TimeoutError): # noqa: PT012
async with timeout.async_timeout(0.1, "test"):
cancelling_inside_timeout = current_task.cancelling()
await asyncio.sleep(0.3)

assert cancelling_inside_timeout == 0
# After the context manager exits, the task should no longer be cancelling
assert current_task.cancelling() == 0


async def test_simple_zone_timeout_does_swallow_cancellation(
hass: HomeAssistant,
) -> None:
"""Test a zone timeout does not swallow cancellation."""
timeout = TimeoutManager()
current_task = asyncio.current_task()
assert current_task is not None
cancelling_inside_timeout = None

async def task_with_timeout() -> None:
nonlocal cancelling_inside_timeout
new_task = asyncio.current_task()
assert new_task is not None
with pytest.raises(asyncio.TimeoutError): # noqa: PT012
async with timeout.async_timeout(0.1, "test"):
cancelling_inside_timeout = current_task.cancelling()
await asyncio.sleep(0.3)

# After the context manager exits, the task should no longer be cancelling
assert current_task.cancelling() == 0

task = asyncio.create_task(task_with_timeout())
await asyncio.sleep(0)
task.cancel()
assert task.cancelling() == 1

# Cancellation should not leak into the current task
assert cancelling_inside_timeout == 0
assert current_task.cancelling() == 0
# Cancellation should not be swallowed if the task is cancelled
# and it also times out
await asyncio.sleep(0)
with pytest.raises(asyncio.CancelledError):
await task
assert task.cancelling() == 1


async def test_multiple_zone_timeout() -> None:
"""Test a simple zone timeout."""
timeout = TimeoutManager()
Expand Down Expand Up @@ -327,7 +439,7 @@ async def test_simple_zone_timeout_freeze_without_timeout_exeption() -> None:
await asyncio.sleep(0.4)


async def test_simple_zone_timeout_zone_with_timeout_exeption() -> None:
async def test_simple_zone_timeout_zone_with_timeout_exception() -> None:
"""Test a simple zone timeout freeze on a zone that does not have a timeout set."""
timeout = TimeoutManager()

Expand Down

0 comments on commit 096d506

Please sign in to comment.