diff --git a/custom_components/powercalc/config_flow.py b/custom_components/powercalc/config_flow.py index 65d6922a5..2dafb7114 100644 --- a/custom_components/powercalc/config_flow.py +++ b/custom_components/powercalc/config_flow.py @@ -1034,7 +1034,7 @@ async def _create_schema( ) -> vol.Schema: """Create sub profile schema.""" library = await ProfileLibrary.factory(self.hass) - profile = await library.get_profile(model_info) + profile = await library.get_profile(model_info, process_variables=False) sub_profiles = [selector.SelectOptionDict(value=sub_profile, label=sub_profile) for sub_profile in await profile.get_sub_profiles()] return vol.Schema( { diff --git a/custom_components/powercalc/discovery.py b/custom_components/powercalc/discovery.py index e33b7c61a..2a50c7c3a 100644 --- a/custom_components/powercalc/discovery.py +++ b/custom_components/powercalc/discovery.py @@ -51,7 +51,7 @@ async def get_power_profile_by_source_entity(hass: HomeAssistant, source_entity: model_info = await discovery_manager.extract_model_info_from_device_info(source_entity.entity_entry) if not model_info: return None - profiles = await discovery_manager.discover_entity(source_entity, model_info) + profiles = await discovery_manager.find_power_profiles(model_info, source_entity, DiscoveryBy.ENTITY) return profiles[0] if profiles else None @@ -203,7 +203,7 @@ async def find_power_profiles( power_profiles = [] for model_info in models: - profile = await get_power_profile(self.hass, {}, model_info=model_info) + profile = await get_power_profile(self.hass, {}, model_info=model_info, process_variables=False) if not profile or profile.discovery_by != discovery_type: # pragma: no cover continue if discovery_type == DiscoveryBy.ENTITY and not await self.is_entity_supported( @@ -266,6 +266,7 @@ async def is_entity_supported( {}, model_info, log_errors=log_profile_loading_errors, + process_variables=False, ) except ModelNotSupportedError: return False diff --git a/custom_components/powercalc/power_profile/factory.py b/custom_components/powercalc/power_profile/factory.py index e69b5d8d7..55b665658 100644 --- a/custom_components/powercalc/power_profile/factory.py +++ b/custom_components/powercalc/power_profile/factory.py @@ -26,6 +26,7 @@ async def get_power_profile( config: dict, model_info: ModelInfo | None = None, log_errors: bool = True, + process_variables: bool = True, ) -> PowerProfile | None: manufacturer = config.get(CONF_MANUFACTURER) model = config.get(CONF_MODEL) @@ -55,6 +56,7 @@ async def get_power_profile( ModelInfo(manufacturer or "", model or "", model_id), custom_model_directory, config.get(CONF_CUSTOM_FIELDS), + process_variables, ) except LibraryError as err: if log_errors: diff --git a/custom_components/powercalc/power_profile/library.py b/custom_components/powercalc/power_profile/library.py index 004335337..84224a863 100644 --- a/custom_components/powercalc/power_profile/library.py +++ b/custom_components/powercalc/power_profile/library.py @@ -3,7 +3,7 @@ import logging import os import re -from typing import NamedTuple, cast +from typing import Any, NamedTuple, cast from homeassistant.core import HomeAssistant from homeassistant.helpers.singleton import singleton @@ -98,6 +98,7 @@ async def get_profile( model_info: ModelInfo, custom_directory: str | None = None, variables: dict[str, str] | None = None, + process_variables: bool = True, ) -> PowerProfile: """Get a power profile for a given manufacturer and model.""" # Support multiple LUT in subdirectories @@ -106,7 +107,7 @@ async def get_profile( (model, sub_profile) = model_info.model.split("/", 1) model_info = ModelInfo(model_info.manufacturer, model, model_info.model_id) - profile = await self.create_power_profile(model_info, custom_directory, variables) + profile = await self.create_power_profile(model_info, custom_directory, variables, process_variables) if sub_profile: await profile.select_sub_profile(sub_profile) @@ -118,6 +119,7 @@ async def create_power_profile( model_info: ModelInfo, custom_directory: str | None = None, variables: dict[str, str] | None = None, + process_variables: bool = True, ) -> PowerProfile: """Create a power profile object from the model JSON data.""" @@ -128,14 +130,29 @@ async def create_power_profile( model_info = next(iter(models)) json_data, directory = await self._load_model_data(model_info.manufacturer, model_info.model, custom_directory) - if variables: - json_data = cast(dict, replace_placeholders(json_data, variables)) + if json_data.get("fields") and process_variables: + self.validate_variables(json_data, variables or {}) + json_data = cast(dict, replace_placeholders(json_data, variables or {})) if linked_profile := json_data.get("linked_profile", json_data.get("linked_lut")): linked_manufacturer, linked_model = linked_profile.split("/") _, directory = await self._load_model_data(linked_manufacturer, linked_model, custom_directory) return await self._create_power_profile_instance(model_info.manufacturer, model_info.model, directory, json_data) + @staticmethod + def validate_variables(json_data: dict[str, Any], variables: dict[str, str]) -> None: + fields = json_data.get("fields", {}).keys() + + # Check if all variables are valid for the model + for variable in variables: + if variable not in fields: + raise LibraryError(f"Variable {variable} is not valid for this model") + + # Check if all fields have corresponding variables + missing_fields = [field for field in fields if field not in variables] + if missing_fields: + raise LibraryError(f"Missing variables for fields: {', '.join(missing_fields)}") + async def find_manufacturers(self, manufacturer: str) -> set[str]: """Resolve the manufacturer, either from the model info or by loading it.""" return await self._loader.find_manufacturers(manufacturer) diff --git a/custom_components/powercalc/sensor.py b/custom_components/powercalc/sensor.py index f18d90129..97df4cebd 100644 --- a/custom_components/powercalc/sensor.py +++ b/custom_components/powercalc/sensor.py @@ -52,6 +52,7 @@ CONF_CREATE_ENERGY_SENSOR, CONF_CREATE_GROUP, CONF_CREATE_UTILITY_METERS, + CONF_CUSTOM_FIELDS, CONF_CUSTOM_MODEL_DIRECTORY, CONF_DAILY_FIXED_ENERGY, CONF_DELAY, @@ -218,6 +219,7 @@ ), vol.Optional(CONF_UNAVAILABLE_POWER): vol.Coerce(float), vol.Optional(CONF_COMPOSITE): COMPOSITE_SCHEMA, + vol.Optional(CONF_CUSTOM_FIELDS): vol.Schema({cv.string: cv.string}), } diff --git a/tests/power_profile/device_types/test_custom_fields.py b/tests/power_profile/device_types/test_custom_fields.py new file mode 100644 index 000000000..4505932f2 --- /dev/null +++ b/tests/power_profile/device_types/test_custom_fields.py @@ -0,0 +1,33 @@ +import logging + +import pytest +from homeassistant.const import CONF_ENTITY_ID, CONF_NAME, STATE_ON +from homeassistant.core import HomeAssistant + +from custom_components.powercalc.const import CONF_CUSTOM_FIELDS, CONF_MANUFACTURER, CONF_MODEL, DUMMY_ENTITY_ID +from tests.common import get_test_config_dir, run_powercalc_setup + + +async def test_custom_field_variables_from_yaml_config(hass: HomeAssistant, caplog: pytest.LogCaptureFixture) -> None: + caplog.set_level(logging.ERROR) + hass.config.config_dir = get_test_config_dir() + + hass.states.async_set("sensor.test", STATE_ON) + await hass.async_block_till_done() + + await run_powercalc_setup( + hass, + { + CONF_ENTITY_ID: DUMMY_ENTITY_ID, + CONF_NAME: "Test", + CONF_MANUFACTURER: "test", + CONF_MODEL: "custom-fields", + CONF_CUSTOM_FIELDS: { + "some_entity": "sensor.test", + }, + }, + ) + + assert not caplog.records + + assert hass.states.get("sensor.test_power").state == "20.00"