diff --git a/homeassistant/components/evohome/__init__.py b/homeassistant/components/evohome/__init__.py index 13673caebb370..4cf8561fc3b2b 100644 --- a/homeassistant/components/evohome/__init__.py +++ b/homeassistant/components/evohome/__init__.py @@ -14,20 +14,14 @@ from evohomeasync.schema import SZ_SESSION_ID import evohomeasync2 as evo from evohomeasync2.schema.const import ( - SZ_ALLOWED_SYSTEM_MODES, SZ_AUTO_WITH_RESET, SZ_CAN_BE_TEMPORARY, - SZ_GATEWAY_ID, - SZ_GATEWAY_INFO, SZ_HEAT_SETPOINT, - SZ_LOCATION_ID, - SZ_LOCATION_INFO, SZ_SETPOINT_STATUS, SZ_STATE_STATUS, SZ_SYSTEM_MODE, SZ_SYSTEM_MODE_STATUS, SZ_TIME_UNTIL, - SZ_TIME_ZONE, SZ_TIMING_MODE, SZ_UNTIL, ) @@ -50,13 +44,14 @@ async_dispatcher_send, ) from homeassistant.helpers.entity import Entity -from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.service import verify_domain_control from homeassistant.helpers.storage import Store from homeassistant.helpers.typing import ConfigType +from homeassistant.helpers.update_coordinator import DataUpdateCoordinator import homeassistant.util.dt as dt_util from .const import ( + ACCESS_TOKEN, ACCESS_TOKEN_EXPIRES, ATTR_DURATION_DAYS, ATTR_DURATION_HOURS, @@ -65,12 +60,11 @@ ATTR_ZONE_TEMP, CONF_LOCATION_IDX, DOMAIN, - GWS, + REFRESH_TOKEN, SCAN_INTERVAL_DEFAULT, SCAN_INTERVAL_MINIMUM, STORAGE_KEY, STORAGE_VER, - TCS, USER_DATA, EvoService, ) @@ -79,6 +73,7 @@ convert_dict, convert_until, dt_aware_to_naive, + dt_local_to_aware, handle_evo_exception, ) @@ -118,91 +113,158 @@ ) -async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: - """Create a (EMEA/EU-based) Honeywell TCC system.""" +class EvoSession: + """Class for evohome client instantiation & authentication.""" + + def __init__(self, hass: HomeAssistant) -> None: + """Initialize the evohome broker and its data structure.""" + + self.hass = hass + + self._session = async_get_clientsession(hass) + self._store = Store[dict[str, Any]](hass, STORAGE_VER, STORAGE_KEY) + + # the main client, which uses the newer API + self.client_v2: evo.EvohomeClient | None = None + self._tokens: dict[str, Any] = {} + + # the older client can be used to obtain high-precision temps (only) + self.client_v1: ev1.EvohomeClient | None = None + self.session_id: str | None = None + + async def authenticate(self, username: str, password: str) -> None: + """Check the user credentials against the web API. - async def load_auth_tokens(store: Store) -> tuple[dict, dict | None]: - app_storage = await store.async_load() - tokens = dict(app_storage or {}) + Will raise evo.AuthenticationFailed if the credentials are invalid. + """ + + if ( + self.client_v2 is None + or username != self.client_v2.username + or password != self.client_v2.password + ): + await self._load_auth_tokens(username) - if tokens.pop(CONF_USERNAME, None) != config[DOMAIN][CONF_USERNAME]: + client_v2 = evo.EvohomeClient( + username, + password, + **self._tokens, + session=self._session, + ) + + else: # force a re-authentication + client_v2 = self.client_v2 + client_v2._user_account = None # noqa: SLF001 + + await client_v2.login() + await self.save_auth_tokens() + + self.client_v2 = client_v2 + + self.client_v1 = ev1.EvohomeClient( + username, + password, + session_id=self.session_id, + session=self._session, + ) + + async def _load_auth_tokens(self, username: str) -> None: + """Load access tokens and session_id from the store and validate them. + + Sets self._tokens and self._session_id to the latest values. + """ + + app_storage: dict[str, Any] = dict(await self._store.async_load() or {}) + + if app_storage.pop(CONF_USERNAME, None) != username: # any tokens won't be valid, and store might be corrupt - await store.async_save({}) - return ({}, {}) + await self._store.async_save({}) + + self.session_id = None + self._tokens = {} + + return # evohomeasync2 requires naive/local datetimes as strings - if tokens.get(ACCESS_TOKEN_EXPIRES) is not None and ( - expires := dt_util.parse_datetime(tokens[ACCESS_TOKEN_EXPIRES]) + if app_storage.get(ACCESS_TOKEN_EXPIRES) is not None and ( + expires := dt_util.parse_datetime(app_storage[ACCESS_TOKEN_EXPIRES]) ): - tokens[ACCESS_TOKEN_EXPIRES] = dt_aware_to_naive(expires) + app_storage[ACCESS_TOKEN_EXPIRES] = dt_aware_to_naive(expires) - user_data = tokens.pop(USER_DATA, {}) - return (tokens, user_data) + user_data: dict[str, str] = app_storage.pop(USER_DATA, {}) - store = Store[dict[str, Any]](hass, STORAGE_VER, STORAGE_KEY) - tokens, user_data = await load_auth_tokens(store) + self.session_id = user_data.get(SZ_SESSION_ID) + self._tokens = app_storage - client_v2 = evo.EvohomeClient( - config[DOMAIN][CONF_USERNAME], - config[DOMAIN][CONF_PASSWORD], - **tokens, - session=async_get_clientsession(hass), - ) + async def save_auth_tokens(self) -> None: + """Save access tokens and session_id to the store. + + Sets self._tokens and self._session_id to the latest values. + """ + + if self.client_v2 is None: + await self._store.async_save({}) + return + + # evohomeasync2 uses naive/local datetimes + access_token_expires = dt_local_to_aware( + self.client_v2.access_token_expires # type: ignore[arg-type] + ) + + self._tokens = { + CONF_USERNAME: self.client_v2.username, + REFRESH_TOKEN: self.client_v2.refresh_token, + ACCESS_TOKEN: self.client_v2.access_token, + ACCESS_TOKEN_EXPIRES: access_token_expires.isoformat(), + } + + self.session_id = self.client_v1.broker.session_id if self.client_v1 else None + + app_storage = self._tokens + if self.client_v1: + app_storage[USER_DATA] = {SZ_SESSION_ID: self.session_id} + + await self._store.async_save(app_storage) + + +async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: + """Set up the Evohome integration.""" + + sess = EvoSession(hass) try: - await client_v2.login() + await sess.authenticate( + config[DOMAIN][CONF_USERNAME], + config[DOMAIN][CONF_PASSWORD], + ) + except evo.AuthenticationFailed as err: handle_evo_exception(err) return False + finally: config[DOMAIN][CONF_PASSWORD] = "REDACTED" - assert isinstance(client_v2.installation_info, list) # mypy + broker = EvoBroker(sess) - loc_idx = config[DOMAIN][CONF_LOCATION_IDX] - try: - loc_config = client_v2.installation_info[loc_idx] - except IndexError: - _LOGGER.error( - ( - "Config error: '%s' = %s, but the valid range is 0-%s. " - "Unable to continue. Fix any configuration errors and restart HA" - ), - CONF_LOCATION_IDX, - loc_idx, - len(client_v2.installation_info) - 1, - ) + if not broker.validate_location( + config[DOMAIN][CONF_LOCATION_IDX], + ): return False - if _LOGGER.isEnabledFor(logging.DEBUG): - loc_info = { - SZ_LOCATION_ID: loc_config[SZ_LOCATION_INFO][SZ_LOCATION_ID], - SZ_TIME_ZONE: loc_config[SZ_LOCATION_INFO][SZ_TIME_ZONE], - } - gwy_info = { - SZ_GATEWAY_ID: loc_config[GWS][0][SZ_GATEWAY_INFO][SZ_GATEWAY_ID], - TCS: loc_config[GWS][0][TCS], - } - _config = { - SZ_LOCATION_INFO: loc_info, - GWS: [{SZ_GATEWAY_INFO: gwy_info}], - } - _LOGGER.debug("Config = %s", _config) - - client_v1 = ev1.EvohomeClient( - client_v2.username, - client_v2.password, - session_id=user_data.get(SZ_SESSION_ID) if user_data else None, # STORAGE_VER 1 - session=async_get_clientsession(hass), + coordinator = DataUpdateCoordinator( + hass, + _LOGGER, + name=f"{DOMAIN}_coordinator", + update_interval=config[DOMAIN][CONF_SCAN_INTERVAL], + update_method=broker.async_update, ) - hass.data[DOMAIN] = {} - hass.data[DOMAIN]["broker"] = broker = EvoBroker( - hass, client_v2, client_v1, store, config[DOMAIN] - ) + hass.data[DOMAIN] = {"broker": broker, "coordinator": coordinator} - await broker.save_auth_tokens() - await broker.async_update() # get initial state + # without a listener, _schedule_refresh() won't be invoked by _async_refresh() + coordinator.async_add_listener(lambda: None) + await coordinator.async_refresh() # get initial state hass.async_create_task( async_load_platform(hass, Platform.CLIMATE, DOMAIN, {}, config) @@ -212,10 +274,6 @@ async def load_auth_tokens(store: Store) -> tuple[dict, dict | None]: async_load_platform(hass, Platform.WATER_HEATER, DOMAIN, {}, config) ) - async_track_time_interval( - hass, broker.async_update, config[DOMAIN][CONF_SCAN_INTERVAL] - ) - setup_service_functions(hass, broker) return True @@ -272,7 +330,7 @@ async def set_zone_override(call: ServiceCall) -> None: hass.services.async_register(DOMAIN, EvoService.REFRESH_SYSTEM, force_refresh) # Enumerate which operating modes are supported by this system - modes = broker.config[SZ_ALLOWED_SYSTEM_MODES] + modes = broker.tcs.allowedSystemModes # Not all systems support "AutoWithReset": register this handler only if required if [m[SZ_SYSTEM_MODE] for m in modes if m[SZ_SYSTEM_MODE] == SZ_AUTO_WITH_RESET]: diff --git a/homeassistant/components/evohome/climate.py b/homeassistant/components/evohome/climate.py index 8b3e8a46e2c45..42ffe84121e0d 100644 --- a/homeassistant/components/evohome/climate.py +++ b/homeassistant/components/evohome/climate.py @@ -9,7 +9,6 @@ import evohomeasync2 as evo from evohomeasync2.schema.const import ( SZ_ACTIVE_FAULTS, - SZ_ALLOWED_SYSTEM_MODES, SZ_SETPOINT_STATUS, SZ_SYSTEM_ID, SZ_SYSTEM_MODE, @@ -44,7 +43,6 @@ ATTR_DURATION_UNTIL, ATTR_SYSTEM_MODE, ATTR_ZONE_TEMP, - CONF_LOCATION_IDX, DOMAIN, EVO_AUTO, EVO_AUTOECO, @@ -112,8 +110,8 @@ async def async_setup_platform( "Found the Location/Controller (%s), id=%s, name=%s (location_idx=%s)", broker.tcs.modelType, broker.tcs.systemId, - broker.tcs.location.name, - broker.params[CONF_LOCATION_IDX], + broker.loc.name, + broker.loc_idx, ) entities: list[EvoClimateEntity] = [EvoController(broker, broker.tcs)] @@ -367,7 +365,7 @@ def __init__(self, evo_broker: EvoBroker, evo_device: evo.ControlSystem) -> None self._attr_unique_id = evo_device.systemId self._attr_name = evo_device.location.name - modes = [m[SZ_SYSTEM_MODE] for m in evo_broker.config[SZ_ALLOWED_SYSTEM_MODES]] + modes = [m[SZ_SYSTEM_MODE] for m in evo_broker.tcs.allowedSystemModes] self._attr_preset_modes = [ TCS_PRESET_TO_HA[m] for m in modes if m in list(TCS_PRESET_TO_HA) ] diff --git a/homeassistant/components/evohome/coordinator.py b/homeassistant/components/evohome/coordinator.py index 6b54c5f464064..b83d2d20c6a66 100644 --- a/homeassistant/components/evohome/coordinator.py +++ b/homeassistant/components/evohome/coordinator.py @@ -5,85 +5,93 @@ from collections.abc import Awaitable from datetime import timedelta import logging -from typing import Any +from typing import TYPE_CHECKING, Any import evohomeasync as ev1 -from evohomeasync.schema import SZ_ID, SZ_SESSION_ID, SZ_TEMP +from evohomeasync.schema import SZ_ID, SZ_TEMP import evohomeasync2 as evo +from evohomeasync2.schema.const import ( + SZ_GATEWAY_ID, + SZ_GATEWAY_INFO, + SZ_LOCATION_ID, + SZ_LOCATION_INFO, + SZ_TIME_ZONE, +) -from homeassistant.const import CONF_USERNAME -from homeassistant.core import HomeAssistant from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.event import async_call_later -from homeassistant.helpers.storage import Store -from homeassistant.helpers.typing import ConfigType - -from .const import ( - ACCESS_TOKEN, - ACCESS_TOKEN_EXPIRES, - CONF_LOCATION_IDX, - DOMAIN, - GWS, - REFRESH_TOKEN, - TCS, - USER_DATA, - UTC_OFFSET, -) -from .helpers import dt_local_to_aware, handle_evo_exception + +from .const import CONF_LOCATION_IDX, DOMAIN, GWS, TCS, UTC_OFFSET +from .helpers import handle_evo_exception + +if TYPE_CHECKING: + from . import EvoSession _LOGGER = logging.getLogger(__name__.rpartition(".")[0]) class EvoBroker: - """Container for evohome client and data.""" + """Broker for evohome client broker.""" + + loc_idx: int + loc: evo.Location + loc_utc_offset: timedelta + tcs: evo.ControlSystem + + def __init__(self, sess: EvoSession) -> None: + """Initialize the evohome broker and its data structure.""" + + self._sess = sess + self.hass = sess.hass + + assert sess.client_v2 is not None # mypy + + self.client = sess.client_v2 + self.client_v1 = sess.client_v1 - def __init__( - self, - hass: HomeAssistant, - client: evo.EvohomeClient, - client_v1: ev1.EvohomeClient | None, - store: Store[dict[str, Any]], - params: ConfigType, - ) -> None: - """Initialize the evohome client and its data structure.""" - self.hass = hass - self.client = client - self.client_v1 = client_v1 - self._store = store - self.params = params - - loc_idx = params[CONF_LOCATION_IDX] - self._location: evo.Location = client.locations[loc_idx] - - assert isinstance(client.installation_info, list) # mypy - - self.config = client.installation_info[loc_idx][GWS][0][TCS][0] - self.tcs: evo.ControlSystem = self._location._gateways[0]._control_systems[0] # noqa: SLF001 - self.loc_utc_offset = timedelta(minutes=self._location.timeZone[UTC_OFFSET]) self.temps: dict[str, float | None] = {} - async def save_auth_tokens(self) -> None: - """Save access tokens and session IDs to the store for later use.""" - # evohomeasync2 uses naive/local datetimes - access_token_expires = dt_local_to_aware( - self.client.access_token_expires # type: ignore[arg-type] - ) + def validate_location(self, loc_idx: int) -> bool: + """Get the default TCS of the specified location.""" - app_storage: dict[str, Any] = { - CONF_USERNAME: self.client.username, - REFRESH_TOKEN: self.client.refresh_token, - ACCESS_TOKEN: self.client.access_token, - ACCESS_TOKEN_EXPIRES: access_token_expires.isoformat(), - } + self.loc_idx = loc_idx - if self.client_v1: - app_storage[USER_DATA] = { - SZ_SESSION_ID: self.client_v1.broker.session_id, - } # this is the schema for STORAGE_VER == 1 - else: - app_storage[USER_DATA] = {} + assert self.client.installation_info is not None # mypy - await self._store.async_save(app_storage) + try: + loc_config = self.client.installation_info[loc_idx] + except IndexError: + _LOGGER.error( + ( + "Config error: '%s' = %s, but the valid range is 0-%s. " + "Unable to continue. Fix any configuration errors and restart HA" + ), + CONF_LOCATION_IDX, + loc_idx, + len(self.client.installation_info) - 1, + ) + return False + + self.loc = self.client.locations[loc_idx] + self.loc_utc_offset = timedelta(minutes=self.loc.timeZone[UTC_OFFSET]) + self.tcs = self.loc._gateways[0]._control_systems[0] # noqa: SLF001 + + if _LOGGER.isEnabledFor(logging.DEBUG): + loc_info = { + SZ_LOCATION_ID: loc_config[SZ_LOCATION_INFO][SZ_LOCATION_ID], + SZ_TIME_ZONE: loc_config[SZ_LOCATION_INFO][SZ_TIME_ZONE], + } + gwy_info = { + SZ_GATEWAY_ID: loc_config[GWS][0][SZ_GATEWAY_INFO][SZ_GATEWAY_ID], + TCS: loc_config[GWS][0][TCS], + } + config = { + SZ_LOCATION_INFO: loc_info, + GWS: [{SZ_GATEWAY_INFO: gwy_info}], + } + _LOGGER.debug("Config = %s", config) + + return True async def call_client_api( self, @@ -108,11 +116,7 @@ async def _update_v1_api_temps(self) -> None: assert self.client_v1 is not None # mypy check - def get_session_id(client_v1: ev1.EvohomeClient) -> str | None: - user_data = client_v1.user_data if client_v1 else None - return user_data.get(SZ_SESSION_ID) if user_data else None # type: ignore[return-value] - - session_id = get_session_id(self.client_v1) + old_session_id = self._sess.session_id try: temps = await self.client_v1.get_temperatures() @@ -146,7 +150,7 @@ def get_session_id(client_v1: ev1.EvohomeClient) -> str | None: raise else: - if str(self.client_v1.location_id) != self._location.locationId: + if str(self.client_v1.location_id) != self.loc.locationId: _LOGGER.warning( "The v2 API's configured location doesn't match " "the v1 API's default location (there is more than one location), " @@ -157,8 +161,8 @@ def get_session_id(client_v1: ev1.EvohomeClient) -> str | None: self.temps = {str(i[SZ_ID]): i[SZ_TEMP] for i in temps} finally: - if self.client_v1 and session_id != self.client_v1.broker.session_id: - await self.save_auth_tokens() + if self.client_v1 and self.client_v1.broker.session_id != old_session_id: + await self._sess.save_auth_tokens() _LOGGER.debug("Temperatures = %s", self.temps) @@ -168,7 +172,7 @@ async def _update_v2_api_state(self, *args: Any) -> None: access_token = self.client.access_token # maybe receive a new token? try: - status = await self._location.refresh_status() + status = await self.loc.refresh_status() except evo.RequestFailed as err: handle_evo_exception(err) else: @@ -176,7 +180,7 @@ async def _update_v2_api_state(self, *args: Any) -> None: _LOGGER.debug("Status = %s", status) finally: if access_token != self.client.access_token: - await self.save_auth_tokens() + await self._sess.save_auth_tokens() async def async_update(self, *args: Any) -> None: """Get the latest state data of an entire Honeywell TCC Location.