From d18b2d8748cf8776af9c9b10cebae101dc282e63 Mon Sep 17 00:00:00 2001 From: Denis Shulyaka Date: Thu, 2 Nov 2023 14:58:26 +0300 Subject: [PATCH] Shield service call from cancellation on REST API connection loss (#102657) * Shield service call from cancellation on connection loss * add test for timeout * Apply suggestions from code review * Apply suggestions from code review * fix merge * Apply suggestions from code review --- homeassistant/components/api/__init__.py | 15 ++++++++++---- tests/components/api/test_init.py | 25 ++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/homeassistant/components/api/__init__.py b/homeassistant/components/api/__init__.py index 077e5ec909315..a9efda904822c 100644 --- a/homeassistant/components/api/__init__.py +++ b/homeassistant/components/api/__init__.py @@ -1,6 +1,6 @@ """Rest API for Home Assistant.""" import asyncio -from asyncio import timeout +from asyncio import shield, timeout from collections.abc import Collection from functools import lru_cache from http import HTTPStatus @@ -62,6 +62,7 @@ DOMAIN = "api" STREAM_PING_PAYLOAD = "ping" STREAM_PING_INTERVAL = 50 # seconds +SERVICE_WAIT_TIMEOUT = 10 CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN) @@ -388,11 +389,17 @@ def _async_save_changed_entities( ) try: - await hass.services.async_call( - domain, service, data, blocking=True, context=context - ) + async with timeout(SERVICE_WAIT_TIMEOUT): + # shield the service call from cancellation on connection drop + await shield( + hass.services.async_call( + domain, service, data, blocking=True, context=context + ) + ) except (vol.Invalid, ServiceNotFound) as ex: raise HTTPBadRequest() from ex + except TimeoutError: + pass finally: cancel_listen() diff --git a/tests/components/api/test_init.py b/tests/components/api/test_init.py index 2d5705403413e..f97b55c3ede53 100644 --- a/tests/components/api/test_init.py +++ b/tests/components/api/test_init.py @@ -352,6 +352,31 @@ def listener(service_call): assert state["attributes"] == {"data": 1} +async def test_api_call_service_timeout( + hass: HomeAssistant, mock_api_client: TestClient +) -> None: + """Test if the API does not fail on long running services.""" + test_value = [] + + fut = hass.loop.create_future() + + async def listener(service_call): + """Wait and return after mock_api_client.post finishes.""" + value = await fut + test_value.append(value) + + hass.services.async_register("test_domain", "test_service", listener) + + with patch("homeassistant.components.api.SERVICE_WAIT_TIMEOUT", 0): + await mock_api_client.post("/api/services/test_domain/test_service") + assert len(test_value) == 0 + fut.set_result(1) + await hass.async_block_till_done() + + assert len(test_value) == 1 + assert test_value[0] == 1 + + async def test_api_template(hass: HomeAssistant, mock_api_client: TestClient) -> None: """Test the template API.""" hass.states.async_set("sensor.temperature", 10)