Skip to content

Commit

Permalink
Shield service call from cancellation on REST API connection loss (#1…
Browse files Browse the repository at this point in the history
…02657)

* Shield service call from cancellation on connection loss

* add test for timeout

* Apply suggestions from code review

* Apply suggestions from code review

* fix merge

* Apply suggestions from code review
  • Loading branch information
Shulyaka authored Nov 2, 2023
1 parent 4a4d2ad commit d18b2d8
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
15 changes: 11 additions & 4 deletions homeassistant/components/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Rest API for Home Assistant."""
import asyncio
from asyncio import timeout
from asyncio import shield, timeout
from collections.abc import Collection
from functools import lru_cache
from http import HTTPStatus
Expand Down Expand Up @@ -62,6 +62,7 @@
DOMAIN = "api"
STREAM_PING_PAYLOAD = "ping"
STREAM_PING_INTERVAL = 50 # seconds
SERVICE_WAIT_TIMEOUT = 10

CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)

Expand Down Expand Up @@ -388,11 +389,17 @@ def _async_save_changed_entities(
)

try:
await hass.services.async_call(
domain, service, data, blocking=True, context=context
)
async with timeout(SERVICE_WAIT_TIMEOUT):
# shield the service call from cancellation on connection drop
await shield(
hass.services.async_call(
domain, service, data, blocking=True, context=context
)
)
except (vol.Invalid, ServiceNotFound) as ex:
raise HTTPBadRequest() from ex
except TimeoutError:
pass
finally:
cancel_listen()

Expand Down
25 changes: 25 additions & 0 deletions tests/components/api/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,31 @@ def listener(service_call):
assert state["attributes"] == {"data": 1}


async def test_api_call_service_timeout(
hass: HomeAssistant, mock_api_client: TestClient
) -> None:
"""Test if the API does not fail on long running services."""
test_value = []

fut = hass.loop.create_future()

async def listener(service_call):
"""Wait and return after mock_api_client.post finishes."""
value = await fut
test_value.append(value)

hass.services.async_register("test_domain", "test_service", listener)

with patch("homeassistant.components.api.SERVICE_WAIT_TIMEOUT", 0):
await mock_api_client.post("/api/services/test_domain/test_service")
assert len(test_value) == 0
fut.set_result(1)
await hass.async_block_till_done()

assert len(test_value) == 1
assert test_value[0] == 1


async def test_api_template(hass: HomeAssistant, mock_api_client: TestClient) -> None:
"""Test the template API."""
hass.states.async_set("sensor.temperature", 10)
Expand Down

0 comments on commit d18b2d8

Please sign in to comment.