From 0bbd15bfda83d51331559f5bf0f7b3df95953dc1 Mon Sep 17 00:00:00 2001 From: Mike Degatano Date: Thu, 25 Jul 2024 11:14:45 -0400 Subject: [PATCH] Restrict stopping core during migrations with force option (#5205) --- supervisor/api/const.py | 1 + supervisor/api/homeassistant.py | 33 ++++++++++++++++++++--- supervisor/api/host.py | 36 ++++++++++++++++++++++---- supervisor/exceptions.py | 7 +++++ tests/api/test_homeassistant.py | 46 +++++++++++++++++++++++++++++++++ tests/api/test_host.py | 42 +++++++++++++++++++++++++++++- 6 files changed, 155 insertions(+), 10 deletions(-) diff --git a/supervisor/api/const.py b/supervisor/api/const.py index 2b7b2cfbf26..be0356d31b6 100644 --- a/supervisor/api/const.py +++ b/supervisor/api/const.py @@ -36,6 +36,7 @@ ATTR_EJECTABLE = "ejectable" ATTR_FALLBACK = "fallback" ATTR_FILESYSTEMS = "filesystems" +ATTR_FORCE = "force" ATTR_GROUP_IDS = "group_ids" ATTR_IDENTIFIERS = "identifiers" ATTR_IS_ACTIVE = "is_active" diff --git a/supervisor/api/homeassistant.py b/supervisor/api/homeassistant.py index 0769ed57e8b..b68215b4a4e 100644 --- a/supervisor/api/homeassistant.py +++ b/supervisor/api/homeassistant.py @@ -1,4 +1,5 @@ """Init file for Supervisor Home Assistant RESTful API.""" + import asyncio from collections.abc import Awaitable import logging @@ -34,9 +35,9 @@ ATTR_WATCHDOG, ) from ..coresys import CoreSysAttributes -from ..exceptions import APIError +from ..exceptions import APIDBMigrationInProgress, APIError from ..validate import docker_image, network_port, version_tag -from .const import ATTR_SAFE_MODE +from .const import ATTR_FORCE, ATTR_SAFE_MODE from .utils import api_process, api_validate _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -66,6 +67,13 @@ SCHEMA_RESTART = vol.Schema( { vol.Optional(ATTR_SAFE_MODE, default=False): vol.Boolean(), + vol.Optional(ATTR_FORCE, default=False): vol.Boolean(), + } +) + +SCHEMA_STOP = vol.Schema( + { + vol.Optional(ATTR_FORCE, default=False): vol.Boolean(), } ) @@ -73,6 +81,17 @@ class APIHomeAssistant(CoreSysAttributes): """Handle RESTful API for Home Assistant functions.""" + async def _check_offline_migration(self, force: bool = False) -> None: + """Check and raise if there's an offline DB migration in progress.""" + if ( + not force + and (state := await self.sys_homeassistant.api.get_api_state()) + and state.offline_db_migration + ): + raise APIDBMigrationInProgress( + "Offline database migration in progress, try again after it has completed" + ) + @api_process async def info(self, request: web.Request) -> dict[str, Any]: """Return host information.""" @@ -154,6 +173,7 @@ async def stats(self, request: web.Request) -> dict[Any, str]: async def update(self, request: web.Request) -> None: """Update Home Assistant.""" body = await api_validate(SCHEMA_UPDATE, request) + await self._check_offline_migration() await asyncio.shield( self.sys_homeassistant.core.update( @@ -163,9 +183,12 @@ async def update(self, request: web.Request) -> None: ) @api_process - def stop(self, request: web.Request) -> Awaitable[None]: + async def stop(self, request: web.Request) -> Awaitable[None]: """Stop Home Assistant.""" - return asyncio.shield(self.sys_homeassistant.core.stop()) + body = await api_validate(SCHEMA_STOP, request) + await self._check_offline_migration(force=body[ATTR_FORCE]) + + return await asyncio.shield(self.sys_homeassistant.core.stop()) @api_process def start(self, request: web.Request) -> Awaitable[None]: @@ -176,6 +199,7 @@ def start(self, request: web.Request) -> Awaitable[None]: async def restart(self, request: web.Request) -> None: """Restart Home Assistant.""" body = await api_validate(SCHEMA_RESTART, request) + await self._check_offline_migration(force=body[ATTR_FORCE]) await asyncio.shield( self.sys_homeassistant.core.restart(safe_mode=body[ATTR_SAFE_MODE]) @@ -185,6 +209,7 @@ async def restart(self, request: web.Request) -> None: async def rebuild(self, request: web.Request) -> None: """Rebuild Home Assistant.""" body = await api_validate(SCHEMA_RESTART, request) + await self._check_offline_migration(force=body[ATTR_FORCE]) await asyncio.shield( self.sys_homeassistant.core.rebuild(safe_mode=body[ATTR_SAFE_MODE]) diff --git a/supervisor/api/host.py b/supervisor/api/host.py index 73a6c74643c..df54f77d8c5 100644 --- a/supervisor/api/host.py +++ b/supervisor/api/host.py @@ -28,7 +28,7 @@ ATTR_TIMEZONE, ) from ..coresys import CoreSysAttributes -from ..exceptions import APIError, HostLogError +from ..exceptions import APIDBMigrationInProgress, APIError, HostLogError from ..host.const import ( PARAM_BOOT_ID, PARAM_FOLLOW, @@ -46,6 +46,7 @@ ATTR_BROADCAST_MDNS, ATTR_DT_SYNCHRONIZED, ATTR_DT_UTC, + ATTR_FORCE, ATTR_IDENTIFIERS, ATTR_LLMNR_HOSTNAME, ATTR_STARTUP_TIME, @@ -64,10 +65,29 @@ SCHEMA_OPTIONS = vol.Schema({vol.Optional(ATTR_HOSTNAME): str}) +# pylint: disable=no-value-for-parameter +SCHEMA_SHUTDOWN = vol.Schema( + { + vol.Optional(ATTR_FORCE, default=False): vol.Boolean(), + } +) +# pylint: enable=no-value-for-parameter + class APIHost(CoreSysAttributes): """Handle RESTful API for host functions.""" + async def _check_ha_offline_migration(self, force: bool) -> None: + """Check if HA has an offline migration in progress and raise if not forced.""" + if ( + not force + and (state := await self.sys_homeassistant.api.get_api_state()) + and state.offline_db_migration + ): + raise APIDBMigrationInProgress( + "Home Assistant offline database migration in progress, please wait until complete before shutting down host" + ) + @api_process async def info(self, request): """Return host information.""" @@ -109,14 +129,20 @@ async def options(self, request): ) @api_process - def reboot(self, request): + async def reboot(self, request): """Reboot host.""" - return asyncio.shield(self.sys_host.control.reboot()) + body = await api_validate(SCHEMA_SHUTDOWN, request) + await self._check_ha_offline_migration(force=body[ATTR_FORCE]) + + return await asyncio.shield(self.sys_host.control.reboot()) @api_process - def shutdown(self, request): + async def shutdown(self, request): """Poweroff host.""" - return asyncio.shield(self.sys_host.control.shutdown()) + body = await api_validate(SCHEMA_SHUTDOWN, request) + await self._check_ha_offline_migration(force=body[ATTR_FORCE]) + + return await asyncio.shield(self.sys_host.control.shutdown()) @api_process def reload(self, request): diff --git a/supervisor/exceptions.py b/supervisor/exceptions.py index 9267289d346..54c3551fd5f 100644 --- a/supervisor/exceptions.py +++ b/supervisor/exceptions.py @@ -1,4 +1,5 @@ """Core Exceptions.""" + from collections.abc import Callable @@ -339,6 +340,12 @@ class APIAddonNotInstalled(APIError): """Not installed addon requested at addons API.""" +class APIDBMigrationInProgress(APIError): + """Service is unavailable due to an offline DB migration is in progress.""" + + status = 503 + + # Service / Discovery diff --git a/tests/api/test_homeassistant.py b/tests/api/test_homeassistant.py index c74aee358cf..eff67def0c2 100644 --- a/tests/api/test_homeassistant.py +++ b/tests/api/test_homeassistant.py @@ -8,6 +8,7 @@ import pytest from supervisor.coresys import CoreSys +from supervisor.homeassistant.api import APIState from supervisor.homeassistant.core import HomeAssistantCore from supervisor.homeassistant.module import HomeAssistant @@ -142,3 +143,48 @@ async def test_api_rebuild( assert container.remove.call_count == 4 assert container.start.call_count == 2 assert safe_mode_marker.exists() + + +@pytest.mark.parametrize("action", ["rebuild", "restart", "stop", "update"]) +async def test_migration_blocks_stopping_core( + api_client: TestClient, + coresys: CoreSys, + action: str, +): + """Test that an offline db migration in progress stops users from stopping/restarting core.""" + coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True) + + resp = await api_client.post(f"/homeassistant/{action}") + assert resp.status == 503 + result = await resp.json() + assert ( + result["message"] + == "Offline database migration in progress, try again after it has completed" + ) + + +async def test_force_rebuild_during_migration(api_client: TestClient, coresys: CoreSys): + """Test force option rebuilds even during a migration.""" + coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True) + + with patch.object(HomeAssistantCore, "rebuild") as rebuild: + await api_client.post("/homeassistant/rebuild", json={"force": True}) + rebuild.assert_called_once() + + +async def test_force_restart_during_migration(api_client: TestClient, coresys: CoreSys): + """Test force option restarts even during a migration.""" + coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True) + + with patch.object(HomeAssistantCore, "restart") as restart: + await api_client.post("/homeassistant/restart", json={"force": True}) + restart.assert_called_once() + + +async def test_force_stop_during_migration(api_client: TestClient, coresys: CoreSys): + """Test force option stops even during a migration.""" + coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True) + + with patch.object(HomeAssistantCore, "stop") as stop: + await api_client.post("/homeassistant/stop", json={"force": True}) + stop.assert_called_once() diff --git a/tests/api/test_host.py b/tests/api/test_host.py index 902701a6485..41815f22a37 100644 --- a/tests/api/test_host.py +++ b/tests/api/test_host.py @@ -1,13 +1,15 @@ """Test Host API.""" -from unittest.mock import ANY, MagicMock +from unittest.mock import ANY, MagicMock, patch from aiohttp.test_utils import TestClient import pytest from supervisor.coresys import CoreSys from supervisor.dbus.resolved import Resolved +from supervisor.homeassistant.api import APIState from supervisor.host.const import LogFormat, LogFormatter +from supervisor.host.control import SystemControl from tests.dbus_service_mocks.base import DBusServiceMock from tests.dbus_service_mocks.systemd import Systemd as SystemdService @@ -324,3 +326,41 @@ async def test_advanced_logs_errors(api_client: TestClient): content == "Invalid content type requested. Only text/plain and text/x-log supported for now." ) + + +@pytest.mark.parametrize("action", ["reboot", "shutdown"]) +async def test_migration_blocks_shutdown( + api_client: TestClient, + coresys: CoreSys, + action: str, +): + """Test that an offline db migration in progress stops users from shuting down or rebooting system.""" + coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True) + + resp = await api_client.post(f"/host/{action}") + assert resp.status == 503 + result = await resp.json() + assert ( + result["message"] + == "Home Assistant offline database migration in progress, please wait until complete before shutting down host" + ) + + +async def test_force_reboot_during_migration(api_client: TestClient, coresys: CoreSys): + """Test force option reboots even during a migration.""" + coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True) + + with patch.object(SystemControl, "reboot") as reboot: + await api_client.post("/host/reboot", json={"force": True}) + reboot.assert_called_once() + + +async def test_force_shutdown_during_migration( + api_client: TestClient, coresys: CoreSys +): + """Test force option shutdown even during a migration.""" + coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True) + + with patch.object(SystemControl, "shutdown") as shutdown: + await api_client.post("/host/shutdown", json={"force": True}) + shutdown.assert_called_once()