Skip to content

Commit

Permalink
Restrict stopping core during migrations with force option (#5205)
Browse files Browse the repository at this point in the history
  • Loading branch information
mdegat01 authored Jul 25, 2024
1 parent 591b9a4 commit 0bbd15b
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 10 deletions.
1 change: 1 addition & 0 deletions supervisor/api/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
ATTR_EJECTABLE = "ejectable"
ATTR_FALLBACK = "fallback"
ATTR_FILESYSTEMS = "filesystems"
ATTR_FORCE = "force"
ATTR_GROUP_IDS = "group_ids"
ATTR_IDENTIFIERS = "identifiers"
ATTR_IS_ACTIVE = "is_active"
Expand Down
33 changes: 29 additions & 4 deletions supervisor/api/homeassistant.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Init file for Supervisor Home Assistant RESTful API."""

import asyncio
from collections.abc import Awaitable
import logging
Expand Down Expand Up @@ -34,9 +35,9 @@
ATTR_WATCHDOG,
)
from ..coresys import CoreSysAttributes
from ..exceptions import APIError
from ..exceptions import APIDBMigrationInProgress, APIError
from ..validate import docker_image, network_port, version_tag
from .const import ATTR_SAFE_MODE
from .const import ATTR_FORCE, ATTR_SAFE_MODE
from .utils import api_process, api_validate

_LOGGER: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -66,13 +67,31 @@
SCHEMA_RESTART = vol.Schema(
{
vol.Optional(ATTR_SAFE_MODE, default=False): vol.Boolean(),
vol.Optional(ATTR_FORCE, default=False): vol.Boolean(),
}
)

SCHEMA_STOP = vol.Schema(
{
vol.Optional(ATTR_FORCE, default=False): vol.Boolean(),
}
)


class APIHomeAssistant(CoreSysAttributes):
"""Handle RESTful API for Home Assistant functions."""

async def _check_offline_migration(self, force: bool = False) -> None:
"""Check and raise if there's an offline DB migration in progress."""
if (
not force
and (state := await self.sys_homeassistant.api.get_api_state())
and state.offline_db_migration
):
raise APIDBMigrationInProgress(
"Offline database migration in progress, try again after it has completed"
)

@api_process
async def info(self, request: web.Request) -> dict[str, Any]:
"""Return host information."""
Expand Down Expand Up @@ -154,6 +173,7 @@ async def stats(self, request: web.Request) -> dict[Any, str]:
async def update(self, request: web.Request) -> None:
"""Update Home Assistant."""
body = await api_validate(SCHEMA_UPDATE, request)
await self._check_offline_migration()

await asyncio.shield(
self.sys_homeassistant.core.update(
Expand All @@ -163,9 +183,12 @@ async def update(self, request: web.Request) -> None:
)

@api_process
def stop(self, request: web.Request) -> Awaitable[None]:
async def stop(self, request: web.Request) -> Awaitable[None]:
"""Stop Home Assistant."""
return asyncio.shield(self.sys_homeassistant.core.stop())
body = await api_validate(SCHEMA_STOP, request)
await self._check_offline_migration(force=body[ATTR_FORCE])

return await asyncio.shield(self.sys_homeassistant.core.stop())

@api_process
def start(self, request: web.Request) -> Awaitable[None]:
Expand All @@ -176,6 +199,7 @@ def start(self, request: web.Request) -> Awaitable[None]:
async def restart(self, request: web.Request) -> None:
"""Restart Home Assistant."""
body = await api_validate(SCHEMA_RESTART, request)
await self._check_offline_migration(force=body[ATTR_FORCE])

await asyncio.shield(
self.sys_homeassistant.core.restart(safe_mode=body[ATTR_SAFE_MODE])
Expand All @@ -185,6 +209,7 @@ async def restart(self, request: web.Request) -> None:
async def rebuild(self, request: web.Request) -> None:
"""Rebuild Home Assistant."""
body = await api_validate(SCHEMA_RESTART, request)
await self._check_offline_migration(force=body[ATTR_FORCE])

await asyncio.shield(
self.sys_homeassistant.core.rebuild(safe_mode=body[ATTR_SAFE_MODE])
Expand Down
36 changes: 31 additions & 5 deletions supervisor/api/host.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
ATTR_TIMEZONE,
)
from ..coresys import CoreSysAttributes
from ..exceptions import APIError, HostLogError
from ..exceptions import APIDBMigrationInProgress, APIError, HostLogError
from ..host.const import (
PARAM_BOOT_ID,
PARAM_FOLLOW,
Expand All @@ -46,6 +46,7 @@
ATTR_BROADCAST_MDNS,
ATTR_DT_SYNCHRONIZED,
ATTR_DT_UTC,
ATTR_FORCE,
ATTR_IDENTIFIERS,
ATTR_LLMNR_HOSTNAME,
ATTR_STARTUP_TIME,
Expand All @@ -64,10 +65,29 @@

SCHEMA_OPTIONS = vol.Schema({vol.Optional(ATTR_HOSTNAME): str})

# pylint: disable=no-value-for-parameter
SCHEMA_SHUTDOWN = vol.Schema(
{
vol.Optional(ATTR_FORCE, default=False): vol.Boolean(),
}
)
# pylint: enable=no-value-for-parameter


class APIHost(CoreSysAttributes):
"""Handle RESTful API for host functions."""

async def _check_ha_offline_migration(self, force: bool) -> None:
"""Check if HA has an offline migration in progress and raise if not forced."""
if (
not force
and (state := await self.sys_homeassistant.api.get_api_state())
and state.offline_db_migration
):
raise APIDBMigrationInProgress(
"Home Assistant offline database migration in progress, please wait until complete before shutting down host"
)

@api_process
async def info(self, request):
"""Return host information."""
Expand Down Expand Up @@ -109,14 +129,20 @@ async def options(self, request):
)

@api_process
def reboot(self, request):
async def reboot(self, request):
"""Reboot host."""
return asyncio.shield(self.sys_host.control.reboot())
body = await api_validate(SCHEMA_SHUTDOWN, request)
await self._check_ha_offline_migration(force=body[ATTR_FORCE])

return await asyncio.shield(self.sys_host.control.reboot())

@api_process
def shutdown(self, request):
async def shutdown(self, request):
"""Poweroff host."""
return asyncio.shield(self.sys_host.control.shutdown())
body = await api_validate(SCHEMA_SHUTDOWN, request)
await self._check_ha_offline_migration(force=body[ATTR_FORCE])

return await asyncio.shield(self.sys_host.control.shutdown())

@api_process
def reload(self, request):
Expand Down
7 changes: 7 additions & 0 deletions supervisor/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Core Exceptions."""

from collections.abc import Callable


Expand Down Expand Up @@ -339,6 +340,12 @@ class APIAddonNotInstalled(APIError):
"""Not installed addon requested at addons API."""


class APIDBMigrationInProgress(APIError):
"""Service is unavailable due to an offline DB migration is in progress."""

status = 503


# Service / Discovery


Expand Down
46 changes: 46 additions & 0 deletions tests/api/test_homeassistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest

from supervisor.coresys import CoreSys
from supervisor.homeassistant.api import APIState
from supervisor.homeassistant.core import HomeAssistantCore
from supervisor.homeassistant.module import HomeAssistant

Expand Down Expand Up @@ -142,3 +143,48 @@ async def test_api_rebuild(
assert container.remove.call_count == 4
assert container.start.call_count == 2
assert safe_mode_marker.exists()


@pytest.mark.parametrize("action", ["rebuild", "restart", "stop", "update"])
async def test_migration_blocks_stopping_core(
api_client: TestClient,
coresys: CoreSys,
action: str,
):
"""Test that an offline db migration in progress stops users from stopping/restarting core."""
coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True)

resp = await api_client.post(f"/homeassistant/{action}")
assert resp.status == 503
result = await resp.json()
assert (
result["message"]
== "Offline database migration in progress, try again after it has completed"
)


async def test_force_rebuild_during_migration(api_client: TestClient, coresys: CoreSys):
"""Test force option rebuilds even during a migration."""
coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True)

with patch.object(HomeAssistantCore, "rebuild") as rebuild:
await api_client.post("/homeassistant/rebuild", json={"force": True})
rebuild.assert_called_once()


async def test_force_restart_during_migration(api_client: TestClient, coresys: CoreSys):
"""Test force option restarts even during a migration."""
coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True)

with patch.object(HomeAssistantCore, "restart") as restart:
await api_client.post("/homeassistant/restart", json={"force": True})
restart.assert_called_once()


async def test_force_stop_during_migration(api_client: TestClient, coresys: CoreSys):
"""Test force option stops even during a migration."""
coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True)

with patch.object(HomeAssistantCore, "stop") as stop:
await api_client.post("/homeassistant/stop", json={"force": True})
stop.assert_called_once()
42 changes: 41 additions & 1 deletion tests/api/test_host.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Test Host API."""

from unittest.mock import ANY, MagicMock
from unittest.mock import ANY, MagicMock, patch

from aiohttp.test_utils import TestClient
import pytest

from supervisor.coresys import CoreSys
from supervisor.dbus.resolved import Resolved
from supervisor.homeassistant.api import APIState
from supervisor.host.const import LogFormat, LogFormatter
from supervisor.host.control import SystemControl

from tests.dbus_service_mocks.base import DBusServiceMock
from tests.dbus_service_mocks.systemd import Systemd as SystemdService
Expand Down Expand Up @@ -324,3 +326,41 @@ async def test_advanced_logs_errors(api_client: TestClient):
content
== "Invalid content type requested. Only text/plain and text/x-log supported for now."
)


@pytest.mark.parametrize("action", ["reboot", "shutdown"])
async def test_migration_blocks_shutdown(
api_client: TestClient,
coresys: CoreSys,
action: str,
):
"""Test that an offline db migration in progress stops users from shuting down or rebooting system."""
coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True)

resp = await api_client.post(f"/host/{action}")
assert resp.status == 503
result = await resp.json()
assert (
result["message"]
== "Home Assistant offline database migration in progress, please wait until complete before shutting down host"
)


async def test_force_reboot_during_migration(api_client: TestClient, coresys: CoreSys):
"""Test force option reboots even during a migration."""
coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True)

with patch.object(SystemControl, "reboot") as reboot:
await api_client.post("/host/reboot", json={"force": True})
reboot.assert_called_once()


async def test_force_shutdown_during_migration(
api_client: TestClient, coresys: CoreSys
):
"""Test force option shutdown even during a migration."""
coresys.homeassistant.api.get_api_state.return_value = APIState("NOT_RUNNING", True)

with patch.object(SystemControl, "shutdown") as shutdown:
await api_client.post("/host/shutdown", json={"force": True})
shutdown.assert_called_once()

0 comments on commit 0bbd15b

Please sign in to comment.