diff --git a/custom_components/stateful_scenes/switch.py b/custom_components/stateful_scenes/switch.py index 8dd86e7..76e049d 100644 --- a/custom_components/stateful_scenes/switch.py +++ b/custom_components/stateful_scenes/switch.py @@ -9,12 +9,13 @@ import voluptuous as vol from homeassistant.components.switch import PLATFORM_SCHEMA, SwitchEntity from homeassistant.config_entries import ConfigEntry -from homeassistant.const import EntityCategory +from homeassistant.const import EntityCategory, STATE_ON from homeassistant.core import HomeAssistant from homeassistant.helpers.device_registry import DeviceInfo from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.event import async_track_state_change_event from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType +from homeassistant.helpers.restore_state import RestoreEntity from . import StatefulScenes from .const import ( @@ -165,7 +166,7 @@ def unregister_callback(self) -> None: self._scene.unregister_callback() -class RestoreOnDeactivate(SwitchEntity): +class RestoreOnDeactivate(SwitchEntity, RestoreEntity): """Switch entity to restore the scene on deactivation.""" _attr_name = "Restore On Deactivate" @@ -173,7 +174,7 @@ class RestoreOnDeactivate(SwitchEntity): _attr_should_poll = True _attr_assumed_state = True - def __init__(self, scene: StatefulScenes) -> None: + def __init__(self, scene: StatefulScenes.Scene) -> None: """Initialize.""" self._scene = scene self._name = f"{scene.name} Restore On Deactivate" @@ -221,3 +222,11 @@ def update(self) -> None: This is the only method that should fetch new data for Home Assistant. """ self._is_on = self._scene.restore_on_deactivate + + async def async_added_to_hass(self): + """Handle entity which will be added.""" + state = await self.async_get_last_state() + if not state: + return + self._scene.set_restore_on_deactivate(state.state == STATE_ON) + self._is_on = state.state == STATE_ON