diff --git a/.github/scripts/profile_library/update-authors.py b/.github/scripts/profile_library/update-authors.py deleted file mode 100644 index 79560ca4f..000000000 --- a/.github/scripts/profile_library/update-authors.py +++ /dev/null @@ -1,97 +0,0 @@ -import glob -import json -import os -import subprocess -import sys - - -def run_git_command(command): - """Run a git command and return the output.""" - result = subprocess.run(command, shell=True, capture_output=True, text=True) - result.check_returncode() # Raise an error if the command fails - return result.stdout.strip() - - -def get_commits_affected_directory(directory: str) -> list: - """Get a list of commits that affected the given directory, including renames.""" - command = f"git log --follow --format='%H' -- '{directory}'" - commits = run_git_command(command) - return commits.splitlines() - - -def get_commit_author(commit_hash: str) -> str: - """Get the author of a given commit.""" - command = f"git show -s --format='%an <%ae>' {commit_hash}" - author = run_git_command(command) - return author - - -def find_first_commit_author(file: str, check_paths: bool = True) -> str | None: - """Find the first commit that affected the directory and return the author's name.""" - commits = get_commits_affected_directory(file) - for commit in reversed(commits): # Process commits from the oldest to newest - command = f"git diff-tree --no-commit-id --name-only -r {commit}" - if not check_paths: - return get_commit_author(commit) - - affected_files = run_git_command(command) - paths = [file.replace("profile_library", "custom_components/powercalc/data"), file.replace("profile_library", "data"), file] - if any(path in affected_files.splitlines() for path in paths): - author = get_commit_author(commit) - return author - return None - - -def process_model_json_files(root_dir): - # Find all model.json files in the directory tree - model_json_files = glob.glob(os.path.join(root_dir, "**", "model.json"), recursive=True) - - for model_json_file in model_json_files: - # Skip sub profiles - if model_json_file.count("/") != 3: - continue - - author = read_author_from_file(os.path.abspath(model_json_file)) - if author: - print(f"Skipping {model_json_file}, author already set to {author}") - continue - - author = find_first_commit_author(model_json_file) - if author is None: - print(f"Skipping {model_json_file}, author not found") - continue - - write_author_to_file(os.path.abspath(model_json_file), author) - print(f"Updated {model_json_file} with author {author}") - - -def read_author_from_file(file_path: str) -> str | None: - """Read the author from the model.json file.""" - with open(file_path) as file: - json_data = json.load(file) - - return json_data.get("author") - - -def write_author_to_file(file_path: str, author: str) -> None: - """Write the author to the model.json file.""" - # Read the existing content - with open(file_path) as file: - json_data = json.load(file) - - json_data["author"] = author - - with open(file_path, "w") as file: - json.dump(json_data, file, indent=2) - - -def main(): - try: - process_model_json_files("profile_library") - except subprocess.CalledProcessError as e: - print(f"Error running git command: {e}") - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/.github/scripts/profile_library/update-library-json.py b/.github/scripts/profile_library/update-library-json.py deleted file mode 100644 index 284d49615..000000000 --- a/.github/scripts/profile_library/update-library-json.py +++ /dev/null @@ -1,135 +0,0 @@ -from __future__ import annotations - -import glob -import json -import os -import sys -from datetime import datetime -from pathlib import Path - -import git - -sys.path.insert( - 1, - os.path.abspath( - os.path.join(Path(__file__), "../../../../custom_components/powercalc"), - ), -) - -PROJECT_ROOT = os.path.realpath(os.path.join(os.path.abspath(__file__), "../../../../")) -DATA_DIR = f"{PROJECT_ROOT}/profile_library" - - -def generate_library_json(model_listing: list[dict]) -> None: - manufacturers: dict[str, dict] = {} - for model in model_listing: - manufacturer_name = model.get("manufacturer") - manufacturer = manufacturers.get(manufacturer_name) - if not manufacturer: - manufacturer = { - **get_manufacturer_json(manufacturer_name), - "models": [], - "device_types": [], - } - manufacturers[manufacturer_name] = manufacturer - - device_type = model.get("device_type") - if device_type not in manufacturer["device_types"]: - manufacturer["device_types"].append(device_type) - - key_mapping = { - "model": "id", - "name": "name", - "device_type": "device_type", - "aliases": "aliases", - "updated_at": "updated_at", - "color_modes": "color_modes", - } - - # Create a new dictionary with updated keys - mapped_dict = {key_mapping.get(key, key): value for key, value in model.items()} - manufacturer["models"].append({key: mapped_dict[key] for key in key_mapping.values() if key in mapped_dict}) - - json_data = { - "manufacturers": list(manufacturers.values()), - } - - with open( - os.path.join(DATA_DIR, "library.json"), - "w", - ) as json_file: - json_file.write(json.dumps(json_data)) - - print("Generated library.json") - - -def get_manufacturer_json(manufacturer: str) -> dict: - json_path = os.path.join(DATA_DIR, manufacturer, "manufacturer.json") - try: - with open(json_path) as json_file: - return json.load(json_file) - except FileNotFoundError: - default_json = {"name": manufacturer, "aliases": []} - with open(json_path, "w", encoding="utf-8") as json_file: - json.dump(default_json, json_file, ensure_ascii=False, indent=4) - git.Repo(PROJECT_ROOT).git.add(json_path) - print(f"Added {json_path}") - return default_json - - -def get_model_list() -> list[dict]: - """Get a listing of all available powercalc models""" - models = [] - for json_path in glob.glob( - f"{DATA_DIR}/*/*/model.json", - recursive=True, - ): - with open(json_path) as json_file: - model_directory = os.path.dirname(json_path) - model_data: dict = json.load(json_file) - color_modes = get_color_modes(model_directory, DATA_DIR, model_data) - updated_at = get_last_commit_time(model_directory).isoformat() - manufacturer = os.path.basename(os.path.dirname(model_directory)) - - model_data.update( - { - "model": os.path.basename(model_directory), - "manufacturer": manufacturer, - "directory": model_directory, - "updated_at": updated_at, - }, - ) - if "device_type" not in model_data: - model_data["device_type"] = "light" - - if color_modes: - model_data["color_modes"] = list(color_modes) - models.append(model_data) - - return models - - -def get_color_modes(model_directory: str, data_dir: str, model_data: dict) -> set: - if "linked_profile" in model_data: - model_directory = os.path.join(data_dir, model_data["linked_profile"]) - - color_modes = set() - for path in glob.glob(f"{model_directory}/**/*.csv.gz", recursive=True): - filename = os.path.basename(path) - index = filename.index(".") - color_mode = filename[:index] - color_modes.add(color_mode) - return color_modes - - -def get_last_commit_time(directory: str) -> datetime: - repo = git.Repo(directory, search_parent_directories=True) - commits = list(repo.iter_commits(paths=directory)) - if commits: - last_commit = commits[0] - return last_commit.committed_datetime - return datetime.fromtimestamp(0) - - -model_list = get_model_list() -generate_library_json(model_list) diff --git a/.github/scripts/profile_library/update-library.py b/.github/scripts/profile_library/update-library.py new file mode 100644 index 000000000..069a1d28c --- /dev/null +++ b/.github/scripts/profile_library/update-library.py @@ -0,0 +1,294 @@ +from __future__ import annotations + +import argparse +import glob +import json +import os +import subprocess +import sys +from datetime import datetime +from pathlib import Path + +import git + +sys.path.insert( + 1, + os.path.abspath( + os.path.join(Path(__file__), "../../../../custom_components/powercalc"), + ), +) + +PROJECT_ROOT = os.path.realpath(os.path.join(os.path.abspath(__file__), "../../../../")) +DATA_DIR = f"{PROJECT_ROOT}/profile_library" + + +def generate_library_json(model_listing: list[dict]) -> None: + manufacturers: dict[str, dict] = {} + for model in model_listing: + manufacturer_name = model.get("manufacturer") + manufacturer = manufacturers.get(manufacturer_name) + if not manufacturer: + manufacturer = { + **get_manufacturer_json(manufacturer_name), + "models": [], + "device_types": [], + } + manufacturers[manufacturer_name] = manufacturer + + device_type = model.get("device_type") + if device_type not in manufacturer["device_types"]: + manufacturer["device_types"].append(device_type) + + key_mapping = { + "model": "id", + "name": "name", + "device_type": "device_type", + "aliases": "aliases", + "updated_at": "updated_at", + "color_modes": "color_modes", + } + + # Create a new dictionary with updated keys + mapped_dict = {key_mapping.get(key, key): value for key, value in model.items()} + manufacturer["models"].append({key: mapped_dict[key] for key in key_mapping.values() if key in mapped_dict}) + + json_data = { + "manufacturers": list(manufacturers.values()), + } + + with open( + os.path.join(DATA_DIR, "library.json"), + "w", + ) as json_file: + json_file.write(json.dumps(json_data)) + + print("Generated library.json") + + +def update_authors(model_listing: list[dict]) -> None: + for model in model_listing: + author = model.get("author") + model_json_path = model.get("full_path") + if author: + #print(f"Skipping {model_json_path}, author already set to {author}") + continue + + author = find_first_commit_author(model_json_path) + if author is None: + print(f"Skipping {model_json_path}, author not found") + continue + + write_author_to_file(model_json_path, author) + print(f"Updated {model_json_path} with author {author}") + +def update_translations(model_listing: list[dict]) -> None: + data_translations: dict[str, str] = {} + description_translations: dict[str, str] = {} + for model in model_listing: + custom_fields = model.get("fields") + if not custom_fields: + #print(f"Skipping {model_json_path}, no custom fields found") + continue + + for key, field_data in custom_fields.items(): + data_translations[key] = field_data.get("name") + description_translations[key] = field_data.get("description") + + if not data_translations: + print(f"No translations found") + return + + translation_file = os.path.join(PROJECT_ROOT, "custom_components/powercalc/translations/en.json") + with open(translation_file) as file: + json_data = json.load(file) + step = "library_custom_fields" + if not step in json_data["config"]["step"]: + json_data["config"]["step"][step] = { + "data": {}, + "data_description": {}, + } + deep_update(json_data["config"]["step"][step]["data"], data_translations) + deep_update(json_data["config"]["step"][step]["data_description"], description_translations) + + with open(translation_file, "w") as file: + json.dump(json_data, file, indent=2) + + +def deep_update(target: dict, updates: dict) -> None: + """ + Recursively updates a dictionary with another dictionary, + only adding keys that are missing. + """ + for key, value in updates.items(): + if isinstance(value, dict) and key in target and isinstance(target[key], dict): + deep_update(target[key], value) + elif key not in target: + target[key] = value + + +def get_manufacturer_json(manufacturer: str) -> dict: + json_path = os.path.join(DATA_DIR, manufacturer, "manufacturer.json") + try: + with open(json_path) as json_file: + return json.load(json_file) + except FileNotFoundError: + default_json = {"name": manufacturer, "aliases": []} + with open(json_path, "w", encoding="utf-8") as json_file: + json.dump(default_json, json_file, ensure_ascii=False, indent=4) + git.Repo(PROJECT_ROOT).git.add(json_path) + print(f"Added {json_path}") + return default_json + + +def get_model_list() -> list[dict]: + """Get a listing of all available powercalc models""" + models = [] + for json_path in glob.glob( + f"{DATA_DIR}/*/*/model.json", + recursive=True, + ): + with open(json_path) as json_file: + model_directory = os.path.dirname(json_path) + model_data: dict = json.load(json_file) + color_modes = get_color_modes(model_directory, DATA_DIR, model_data) + updated_at = get_last_commit_time(model_directory).isoformat() + manufacturer = os.path.basename(os.path.dirname(model_directory)) + + model_data.update( + { + "model": os.path.basename(model_directory), + "manufacturer": manufacturer, + "directory": model_directory, + "updated_at": updated_at, + "full_path": json_path, + }, + ) + if "device_type" not in model_data: + model_data["device_type"] = "light" + + if color_modes: + model_data["color_modes"] = list(color_modes) + models.append(model_data) + + return models + + +def get_color_modes(model_directory: str, data_dir: str, model_data: dict) -> set: + if "linked_profile" in model_data: + model_directory = os.path.join(data_dir, model_data["linked_profile"]) + + color_modes = set() + for path in glob.glob(f"{model_directory}/**/*.csv.gz", recursive=True): + filename = os.path.basename(path) + index = filename.index(".") + color_mode = filename[:index] + color_modes.add(color_mode) + return color_modes + + +def get_last_commit_time(directory: str) -> datetime: + try: + # Use subprocess to run the git command + result = subprocess.run( + ["git", "log", "-1", "--format=%ct", "--", directory], + cwd=directory, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True + ) + out = result.stdout.strip() + if not out: + return datetime.fromtimestamp(0) + timestamp = int(out) + return datetime.fromtimestamp(timestamp) + except subprocess.CalledProcessError: + # Handle case where there are no commits or Git command fails + return datetime.fromtimestamp(0) + +def run_git_command(command): + """Run a git command and return the output.""" + result = subprocess.run(command, shell=True, capture_output=True, text=True) + result.check_returncode() # Raise an error if the command fails + return result.stdout.strip() + + +def get_commits_affected_directory(directory: str) -> list: + """Get a list of commits that affected the given directory, including renames.""" + command = f"git log --follow --format='%H' -- '{directory}'" + commits = run_git_command(command) + return commits.splitlines() + + +def get_commit_author(commit_hash: str) -> str: + """Get the author of a given commit.""" + command = f"git show -s --format='%an <%ae>' {commit_hash}" + author = run_git_command(command) + return author + + +def find_first_commit_author(file: str, check_paths: bool = True) -> str | None: + """Find the first commit that affected the directory and return the author's name.""" + commits = get_commits_affected_directory(file) + for commit in reversed(commits): # Process commits from the oldest to newest + command = f"git diff-tree --no-commit-id --name-only -r {commit}" + if not check_paths: + return get_commit_author(commit) + + affected_files = run_git_command(command) + file = file.replace(PROJECT_ROOT, "").lstrip("/") + paths = [file.replace("profile_library", "custom_components/powercalc/data"), file.replace("profile_library", "data"), file] + if any(path in affected_files.splitlines() for path in paths): + author = get_commit_author(commit) + return author + return None + +def read_author_from_file(file_path: str) -> str | None: + """Read the author from the model.json file.""" + with open(file_path) as file: + json_data = json.load(file) + + return json_data.get("author") + + +def write_author_to_file(file_path: str, author: str) -> None: + """Write the author to the model.json file.""" + # Read the existing content + with open(file_path) as file: + json_data = json.load(file) + + json_data["author"] = author + + with open(file_path, "w") as file: + json.dump(json_data, file, indent=2) + +def main(): + parser = argparse.ArgumentParser(description="Process profiles JSON files and perform updates.") + parser.add_argument("--authors", action="store_true", help="Update authors") + parser.add_argument("--library-json", action="store_true", help="Generate library.json") + parser.add_argument("--translations", action="store_true", help="Update translations") + parser.add_argument("--all", action="store_true", help="Run all operations (default if no arguments)") + + args = parser.parse_args() + + # Determine whether to run all operations + run_all = not any([args.authors, args.library_json, args.translations]) or args.all + + print("Start reading profiles JSON files..") + model_list = get_model_list() + print(f"Found {len(model_list)} profiles") + + if run_all or args.library_json: + print("Generating library.json..") + generate_library_json(model_list) + + if run_all or args.authors: + print("Updating authors..") + update_authors(model_list) + + if run_all or args.translations: + print("Updating translations..") + update_translations(model_list) + +if __name__ == "__main__": + main() diff --git a/.github/workflows/update-profile-library.yml b/.github/workflows/update-profile-library.yml index 6c35e96d0..d1554551e 100644 --- a/.github/workflows/update-profile-library.yml +++ b/.github/workflows/update-profile-library.yml @@ -29,12 +29,9 @@ jobs: python -m pip install -r ${{ github.workspace }}/.github/scripts/profile_library/requirements.txt - name: Pull again run: git pull || true - - name: Generate library.json + - name: Update library.json, authors and translations run: | - python3 ${{ github.workspace }}/.github/scripts/profile_library/update-library-json.py - - name: Update authors - run: | - python3 ${{ github.workspace }}/.github/scripts/profile_library/update-authors.py + python3 ${{ github.workspace }}/.github/scripts/profile_library/update-library.py - uses: EndBug/add-and-commit@v9 if: github.ref == 'refs/heads/master' with: diff --git a/custom_components/powercalc/config_flow.py b/custom_components/powercalc/config_flow.py index 9c8b93b7d..b4f8d1d27 100644 --- a/custom_components/powercalc/config_flow.py +++ b/custom_components/powercalc/config_flow.py @@ -104,6 +104,7 @@ CONF_UTILITY_METER_TYPES, CONF_VALUE, CONF_VALUE_TEMPLATE, + CONF_VARIABLES, DISCOVERY_POWER_PROFILES, DISCOVERY_SOURCE_ENTITY, DOMAIN, @@ -120,6 +121,7 @@ ) from .discovery import get_power_profile_by_source_entity from .errors import ModelNotSupportedError, StrategyConfigurationError +from .flow_helper.dynamic_field_builder import build_dynamic_field_schema from .group_include.include import find_entities from .power_profile.factory import get_power_profile from .power_profile.library import ModelInfo, ProfileLibrary @@ -144,6 +146,7 @@ class Step(StrEnum): GROUP_TRACKED_UNTRACKED_MANUAL = "group_tracked_untracked_manual" LIBRARY = "library" POST_LIBRARY = "post_library" + LIBRARY_CUSTOM_FIELDS = "library_custom_fields" LIBRARY_MULTI_PROFILE = "library_multi_profile" LIBRARY_OPTIONS = "library_options" VIRTUAL_POWER = "virtual_power" @@ -932,7 +935,7 @@ async def _validate(user_input: dict[str, Any]) -> dict[str, str]: ), ) self.selected_profile = profile - if self.selected_profile and not await self.selected_profile.has_sub_profiles: + if self.selected_profile and not await self.selected_profile.needs_user_configuration: await self.validate_strategy_config() return user_input @@ -974,27 +977,45 @@ async def async_step_post_library( Handles the logic after the user either selected manufacturer/model himself or confirmed autodiscovered. Forwards to the next step in the flow. """ - if self.selected_profile and await self.selected_profile.has_sub_profiles and not self.selected_profile.sub_profile_select: + if not self.selected_profile: + return self.async_abort(reason="model_not_supported") # pragma: no cover + + if self.selected_profile.has_custom_fields and not self.sensor_config.get(CONF_VARIABLES): + return await self.async_step_library_custom_fields() + + if await self.selected_profile.has_sub_profiles and not self.selected_profile.sub_profile_select: return await self.async_step_sub_profile() - if ( - self.selected_profile - and self.selected_profile.device_type == DeviceType.SMART_SWITCH - and self.selected_profile.calculation_strategy == CalculationStrategy.FIXED - ): + if self.selected_profile.device_type == DeviceType.SMART_SWITCH and self.selected_profile.calculation_strategy == CalculationStrategy.FIXED: return await self.async_step_smart_switch() - if self.selected_profile and self.selected_profile.needs_fixed_config: # pragma: no cover + if self.selected_profile.needs_fixed_config: # pragma: no cover return await self.async_step_fixed() - if self.selected_profile and self.selected_profile.needs_linear_config: + if self.selected_profile.needs_linear_config: return await self.async_step_linear() - if self.selected_profile and self.selected_profile.calculation_strategy == CalculationStrategy.MULTI_SWITCH: + if self.selected_profile.calculation_strategy == CalculationStrategy.MULTI_SWITCH: return await self.async_step_multi_switch() return await self.async_step_power_advanced() + async def async_step_library_custom_fields(self, user_input: dict[str, Any] | None = None) -> FlowResult: + """Handle the flow for custom fields.""" + + async def _process_user_input(user_input: dict[str, Any]) -> dict[str, Any]: + return {CONF_VARIABLES: user_input} + + return await self.handle_form_step( + PowercalcFormStep( + step=Step.LIBRARY_CUSTOM_FIELDS, + schema=build_dynamic_field_schema(self.selected_profile), # type: ignore + next_step=Step.POST_LIBRARY, + validate_user_input=_process_user_input, + ), + user_input, + ) + async def async_step_sub_profile( self, user_input: dict[str, Any] | None = None, @@ -1009,7 +1030,7 @@ async def _create_schema( ) -> vol.Schema: """Create sub profile schema.""" library = await ProfileLibrary.factory(self.hass) - profile = await library.get_profile(model_info) + profile = await library.get_profile(model_info, process_variables=False) sub_profiles = [selector.SelectOptionDict(value=sub_profile, label=sub_profile) for sub_profile in await profile.get_sub_profiles()] return vol.Schema( { diff --git a/custom_components/powercalc/const.py b/custom_components/powercalc/const.py index 311da45ee..77829f50e 100644 --- a/custom_components/powercalc/const.py +++ b/custom_components/powercalc/const.py @@ -46,6 +46,7 @@ CONF_CREATE_ENERGY_SENSOR = "create_energy_sensor" CONF_CREATE_ENERGY_SENSORS = "create_energy_sensors" CONF_CREATE_UTILITY_METERS = "create_utility_meters" +CONF_VARIABLES = "variables" CONF_DAILY_FIXED_ENERGY = "daily_fixed_energy" CONF_DELAY = "delay" CONF_DISABLE_LIBRARY_DOWNLOAD = "disable_library_download" diff --git a/custom_components/powercalc/discovery.py b/custom_components/powercalc/discovery.py index e33b7c61a..90fa6af4a 100644 --- a/custom_components/powercalc/discovery.py +++ b/custom_components/powercalc/discovery.py @@ -32,7 +32,6 @@ MANUFACTURER_WLED, CalculationStrategy, ) -from .errors import ModelNotSupportedError from .group_include.filter import CategoryFilter, CompositeFilter, FilterOperator, LambdaFilter, NotFilter, get_filtered_entity_list from .helpers import get_or_create_unique_id from .power_profile.factory import get_power_profile @@ -51,7 +50,7 @@ async def get_power_profile_by_source_entity(hass: HomeAssistant, source_entity: model_info = await discovery_manager.extract_model_info_from_device_info(source_entity.entity_entry) if not model_info: return None - profiles = await discovery_manager.discover_entity(source_entity, model_info) + profiles = await discovery_manager.find_power_profiles(model_info, source_entity, DiscoveryBy.ENTITY) return profiles[0] if profiles else None @@ -203,13 +202,11 @@ async def find_power_profiles( power_profiles = [] for model_info in models: - profile = await get_power_profile(self.hass, {}, model_info=model_info) + profile = await get_power_profile(self.hass, {}, model_info=model_info, process_variables=False) if not profile or profile.discovery_by != discovery_type: # pragma: no cover continue - if discovery_type == DiscoveryBy.ENTITY and not await self.is_entity_supported( + if discovery_type == DiscoveryBy.ENTITY and not profile.is_entity_domain_supported( source_entity.entity_entry, # type: ignore[arg-type] - model_info, - profile, ): continue power_profiles.append(profile) @@ -247,31 +244,6 @@ def is_wled_light(model_info: ModelInfo, entity_entry: er.RegistryEntry) -> bool and not re.search("master|segment", str(entity_entry.entity_id), flags=re.IGNORECASE) ) - async def is_entity_supported( - self, - entity_entry: er.RegistryEntry, - model_info: ModelInfo | None = None, - power_profile: PowerProfile | None = None, - log_profile_loading_errors: bool = True, - ) -> bool: - if not model_info: - model_info = await self.extract_model_info_from_device_info(entity_entry) - if not model_info or not model_info.manufacturer or not model_info.model: - return False - - if not power_profile: - try: - power_profile = await get_power_profile( - self.hass, - {}, - model_info, - log_errors=log_profile_loading_errors, - ) - except ModelNotSupportedError: - return False - - return power_profile.is_entity_domain_supported(entity_entry) if power_profile else False - async def get_entities(self) -> list[er.RegistryEntry]: """Get all entities from entity registry which qualifies for discovery.""" diff --git a/custom_components/powercalc/flow_helper/__init__.py b/custom_components/powercalc/flow_helper/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/custom_components/powercalc/flow_helper/dynamic_field_builder.py b/custom_components/powercalc/flow_helper/dynamic_field_builder.py new file mode 100644 index 000000000..2b483b6aa --- /dev/null +++ b/custom_components/powercalc/flow_helper/dynamic_field_builder.py @@ -0,0 +1,15 @@ +import voluptuous as vol +from homeassistant.helpers.selector import selector + +from custom_components.powercalc.power_profile.power_profile import PowerProfile + + +def build_dynamic_field_schema(profile: PowerProfile) -> vol.Schema: + schema = {} + for field in profile.custom_fields: + field_description = field.description + if not field_description: + field_description = field.label + key = vol.Required(field.key, description=field_description) + schema[key] = selector(field.selector) + return vol.Schema(schema) diff --git a/custom_components/powercalc/group_include/include.py b/custom_components/powercalc/group_include/include.py index 8a9c5a231..f51bd232c 100644 --- a/custom_components/powercalc/group_include/include.py +++ b/custom_components/powercalc/group_include/include.py @@ -5,13 +5,13 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers.entity import Entity -from custom_components.powercalc import DiscoveryManager +from custom_components.powercalc.common import create_source_entity from custom_components.powercalc.const import ( DATA_CONFIGURED_ENTITIES, - DATA_DISCOVERY_MANAGER, DATA_ENTITIES, DOMAIN, ) +from custom_components.powercalc.discovery import get_power_profile_by_source_entity from custom_components.powercalc.sensors.energy import RealEnergySensor from custom_components.powercalc.sensors.power import RealPowerSensor @@ -29,7 +29,6 @@ async def find_entities( Based on given entity filter fetch all power and energy sensors from the HA instance """ domain_data = hass.data.get(DOMAIN, {}) - discovery_manager: DiscoveryManager = domain_data.get(DATA_DISCOVERY_MANAGER) resolved_entities: list[Entity] = [] discoverable_entities: list[str] = [] @@ -55,7 +54,11 @@ async def find_entities( elif device_class == SensorDeviceClass.ENERGY and source_entity.platform != "utility_meter": resolved_entities.append(RealEnergySensor(source_entity.entity_id)) - if source_entity and await discovery_manager.is_entity_supported(source_entity, None, log_profile_loading_errors=False): + power_profile = await get_power_profile_by_source_entity( + hass, + await create_source_entity(source_entity.entity_id, hass), + ) + if power_profile and not await power_profile.needs_user_configuration and power_profile.is_entity_domain_supported(source_entity): discoverable_entities.append(source_entity.entity_id) return resolved_entities, discoverable_entities diff --git a/custom_components/powercalc/helpers.py b/custom_components/powercalc/helpers.py index e6621ca64..b9998b02b 100644 --- a/custom_components/powercalc/helpers.py +++ b/custom_components/powercalc/helpers.py @@ -1,6 +1,7 @@ import decimal import logging import os.path +import re import uuid from collections.abc import Callable, Coroutine from decimal import Decimal @@ -115,3 +116,21 @@ async def wrapper(*args: Any, **kwargs: Any) -> R: # noqa: ANN401 return result return wrapper + + +def replace_placeholders(data: list | str | dict[str, Any], replacements: dict[str, str]) -> list | str | dict[str, Any]: + """Replace placeholders in a dictionary with values from a replacement dictionary.""" + if isinstance(data, dict): + for key, value in data.items(): + data[key] = replace_placeholders(value, replacements) + elif isinstance(data, list): + for i in range(len(data)): + data[i] = replace_placeholders(data[i], replacements) + elif isinstance(data, str): + # Adjust regex to match [[variable]] + matches = re.findall(r"\[\[\s*(\w+)\s*\]\]", data) + for match in matches: + if match in replacements: + # Replace [[variable]] with its value + data = data.replace(f"[[{match}]]", replacements[match]) + return data diff --git a/custom_components/powercalc/power_profile/factory.py b/custom_components/powercalc/power_profile/factory.py index 2263ce0fb..22023e501 100644 --- a/custom_components/powercalc/power_profile/factory.py +++ b/custom_components/powercalc/power_profile/factory.py @@ -3,12 +3,14 @@ import logging import os +from homeassistant.const import CONF_ENTITY_ID from homeassistant.core import HomeAssistant from custom_components.powercalc.const import ( CONF_CUSTOM_MODEL_DIRECTORY, CONF_MANUFACTURER, CONF_MODEL, + CONF_VARIABLES, MANUFACTURER_WLED, ) from custom_components.powercalc.errors import ModelNotSupportedError @@ -25,6 +27,7 @@ async def get_power_profile( config: dict, model_info: ModelInfo | None = None, log_errors: bool = True, + process_variables: bool = True, ) -> PowerProfile | None: manufacturer = config.get(CONF_MANUFACTURER) model = config.get(CONF_MODEL) @@ -50,9 +53,15 @@ async def get_power_profile( library = await ProfileLibrary.factory(hass) try: + variables = config.get(CONF_VARIABLES, {}).copy() + if CONF_ENTITY_ID in config: + variables["entity"] = config[CONF_ENTITY_ID] + profile = await library.get_profile( ModelInfo(manufacturer or "", model or "", model_id), custom_model_directory, + variables, + process_variables, ) except LibraryError as err: if log_errors: diff --git a/custom_components/powercalc/power_profile/library.py b/custom_components/powercalc/power_profile/library.py index e5f16e768..fae2dbfd4 100644 --- a/custom_components/powercalc/power_profile/library.py +++ b/custom_components/powercalc/power_profile/library.py @@ -3,12 +3,13 @@ import logging import os import re -from typing import NamedTuple +from typing import Any, NamedTuple, cast from homeassistant.core import HomeAssistant from homeassistant.helpers.singleton import singleton from custom_components.powercalc.const import CONF_DISABLE_LIBRARY_DOWNLOAD, DOMAIN, DOMAIN_CONFIG +from custom_components.powercalc.helpers import replace_placeholders from .error import LibraryError from .loader.composite import CompositeLoader @@ -96,6 +97,8 @@ async def get_profile( self, model_info: ModelInfo, custom_directory: str | None = None, + variables: dict[str, str] | None = None, + process_variables: bool = True, ) -> PowerProfile: """Get a power profile for a given manufacturer and model.""" # Support multiple LUT in subdirectories @@ -104,7 +107,7 @@ async def get_profile( (model, sub_profile) = model_info.model.split("/", 1) model_info = ModelInfo(model_info.manufacturer, model, model_info.model_id) - profile = await self.create_power_profile(model_info, custom_directory) + profile = await self.create_power_profile(model_info, custom_directory, variables, process_variables) if sub_profile: await profile.select_sub_profile(sub_profile) @@ -115,6 +118,8 @@ async def create_power_profile( self, model_info: ModelInfo, custom_directory: str | None = None, + variables: dict[str, str] | None = None, + process_variables: bool = True, ) -> PowerProfile: """Create a power profile object from the model JSON data.""" @@ -125,12 +130,31 @@ async def create_power_profile( model_info = next(iter(models)) json_data, directory = await self._load_model_data(model_info.manufacturer, model_info.model, custom_directory) + if process_variables: + if json_data.get("fields"): # When custom fields in profile are defined, make sure all variables are passed + self.validate_variables(json_data, variables or {}) + json_data = cast(dict, replace_placeholders(json_data, variables or {})) + if linked_profile := json_data.get("linked_profile", json_data.get("linked_lut")): linked_manufacturer, linked_model = linked_profile.split("/") _, directory = await self._load_model_data(linked_manufacturer, linked_model, custom_directory) return await self._create_power_profile_instance(model_info.manufacturer, model_info.model, directory, json_data) + @staticmethod + def validate_variables(json_data: dict[str, Any], variables: dict[str, str]) -> None: + fields = json_data.get("fields", {}).keys() + + # Check if all variables are valid for the model + for variable in variables: + if variable not in fields and variable != "entity": + raise LibraryError(f"Variable {variable} is not valid for this model") + + # Check if all fields have corresponding variables + missing_fields = [field for field in fields if field not in variables] + if missing_fields: + raise LibraryError(f"Missing variables for fields: {', '.join(missing_fields)}") + async def find_manufacturers(self, manufacturer: str) -> set[str]: """Resolve the manufacturer, either from the model info or by loading it.""" return await self._loader.find_manufacturers(manufacturer) diff --git a/custom_components/powercalc/power_profile/power_profile.py b/custom_components/powercalc/power_profile/power_profile.py index 0928acc41..860e1060c 100644 --- a/custom_components/powercalc/power_profile/power_profile.py +++ b/custom_components/powercalc/power_profile/power_profile.py @@ -5,8 +5,9 @@ import os import re from collections import defaultdict +from dataclasses import dataclass from enum import StrEnum -from typing import NamedTuple, Protocol, cast +from typing import Any, NamedTuple, Protocol, cast from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAIN from homeassistant.components.camera import DOMAIN as CAMERA_DOMAIN @@ -59,6 +60,14 @@ class SubProfileMatcherType(StrEnum): INTEGRATION = "integration" +@dataclass(frozen=True) +class CustomField: + key: str + label: str + selector: dict[str, Any] + description: str | None = None + + DEVICE_TYPE_DOMAIN = { DeviceType.CAMERA: CAMERA_DOMAIN, DeviceType.COVER: COVER_DOMAIN, @@ -177,7 +186,7 @@ def linear_config(self) -> ConfigType | None: @property def multi_switch_config(self) -> ConfigType | None: - """Get configuration to set up linear strategy.""" + """Get configuration to set up multi_switch strategy.""" return self.get_strategy_config(CalculationStrategy.MULTI_SWITCH) @property @@ -258,6 +267,16 @@ def only_self_usage(self) -> bool: """Whether this profile only provides self usage.""" return bool(self._json_data.get("only_self_usage", False)) + @property + def has_custom_fields(self) -> bool: + """Whether this profile has custom fields.""" + return bool(self._json_data.get("fields")) + + @property + def custom_fields(self) -> list[CustomField]: + """Get the custom fields of this profile.""" + return [CustomField(key=key, **field) for key, field in self._json_data.get("fields", {}).items()] + @property def config_flow_discovery_remarks(self) -> str | None: """Get remarks to show at the config flow discovery step.""" @@ -329,6 +348,20 @@ def _load_json() -> None: self.sub_profile = sub_profile + @property + async def needs_user_configuration(self) -> bool: + """Check whether this profile needs user configuration.""" + if self.calculation_strategy == CalculationStrategy.MULTI_SWITCH: + return True + + if self.needs_fixed_config or self.needs_linear_config: + return True + + if self.has_custom_fields: + return True + + return await self.has_sub_profiles and not self.sub_profile_select + def is_entity_domain_supported(self, entity_entry: RegistryEntry) -> bool: """Check whether this power profile supports a given entity domain.""" if self.device_type is None: diff --git a/custom_components/powercalc/sensor.py b/custom_components/powercalc/sensor.py index f18d90129..9cc8438b5 100644 --- a/custom_components/powercalc/sensor.py +++ b/custom_components/powercalc/sensor.py @@ -97,6 +97,7 @@ CONF_UTILITY_METER_TYPES, CONF_VALUE, CONF_VALUE_TEMPLATE, + CONF_VARIABLES, CONF_WLED, DATA_CONFIGURED_ENTITIES, DATA_DOMAIN_ENTITIES, @@ -218,6 +219,7 @@ ), vol.Optional(CONF_UNAVAILABLE_POWER): vol.Coerce(float), vol.Optional(CONF_COMPOSITE): COMPOSITE_SCHEMA, + vol.Optional(CONF_VARIABLES): vol.Schema({cv.string: cv.string}), } diff --git a/custom_components/powercalc/sensors/power.py b/custom_components/powercalc/sensors/power.py index c2386e67f..3e8928e2e 100644 --- a/custom_components/powercalc/sensors/power.py +++ b/custom_components/powercalc/sensors/power.py @@ -490,7 +490,6 @@ def init_calculation_enabled_condition(self) -> None: template: Template | str = self._sensor_config.get(CONF_CALCULATION_ENABLED_CONDITION) # type: ignore if isinstance(template, str): - template = template.replace("[[entity]]", self.source_entity) template = Template(template, self.hass) self._calculation_enabled_condition = template diff --git a/custom_components/powercalc/strategy/factory.py b/custom_components/powercalc/strategy/factory.py index 29a14c0a8..95dc798c6 100644 --- a/custom_components/powercalc/strategy/factory.py +++ b/custom_components/powercalc/strategy/factory.py @@ -2,7 +2,7 @@ from collections.abc import Callable from decimal import Decimal -from typing import cast +from typing import Any, cast from homeassistant.const import CONF_CONDITION, CONF_ENTITIES, CONF_ENTITY_ID from homeassistant.core import HomeAssistant @@ -104,14 +104,11 @@ def _create_fixed( power = fixed_config.get(CONF_POWER) if power is None: power = fixed_config.get(CONF_POWER_TEMPLATE) - if isinstance(power, Template): - power.hass = self._hass + power = self._resolve_template(power) states_power = fixed_config.get(CONF_STATES_POWER) if states_power: - for p in states_power.values(): - if isinstance(p, Template): - p.hass = self._hass + states_power = {state: self._resolve_template(value) for state, value in states_power.items()} return FixedStrategy(source_entity, power, states_power) @@ -218,6 +215,20 @@ def _create_multi_switch(self, config: ConfigType, power_profile: PowerProfile | off_power=Decimal(off_power), ) + def _resolve_template(self, value: Any) -> Any: # noqa: ANN401 + """ + Process the input to ensure it is a Template if applicable. + Otherwise, return the original value. + """ + if isinstance(value, str) and value.startswith("{{"): + template = Template(value) + template.hass = self._hass + return template + if isinstance(value, Template): + value.hass = self._hass + return value + return value + @staticmethod def _get_strategy_config( strategy: CalculationStrategy, diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 9cb5419db..0d86fed64 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -81,6 +81,7 @@ nav: - 'library/library.md' - 'library/structure.md' - 'library/sub-profiles.md' + - 'library/custom-fields.md' - 'Device types': - 'library/device-types/index.md' - 'library/device-types/camera.md' diff --git a/docs/source/library/custom-fields.md b/docs/source/library/custom-fields.md new file mode 100644 index 000000000..416b3db64 --- /dev/null +++ b/docs/source/library/custom-fields.md @@ -0,0 +1,104 @@ +# Custom fields + +Sometimes there is a need to ask the user to provide some additional data for a profile. +This can be done by adding custom fields to the profile configuration. +During discovery flow, or when user adds from library their will be an additional step where the user can provide the custom fields. + +## Adding custom fields + +You can add one or more custom fields to a profile by adding a `fields` section to the profile configuration. + +```json +{ + "fields": { + "switch_entity": { + "label": "Switch entity", + "description": "Select the switch entity for your device", + "selector": { + "entity": { + "domain": "switch" + } + } + } + } +} +``` + +The key `switch_entity` is the key of the field. This can be referenced in the profile configuration using the `[[switch_entity]]` syntax. +After setup Powercalc will replace this with the value the user provided. + +`label` is the label of the field that will be shown to the user. +`description` is optional and is shown to the user below the field. +`selector` is the type of field. The configuration is similar to [HA Blueprints](https://www.home-assistant.io/docs/blueprint/selectors/). + +!!! note + Not all selectors are tested. Some might not be supported. `number` and `entity` are tested and should work. + +### Example number selector + +In the example below we have a profile that asks the user to provide a number. +The profile then calculates the power usage based on the number provided. + +```json +{ + "calculation_strategy": "fixed", + "fields": { + "num_switches": { + "label": "Number of switches", + "description": "Enter some number", + "selector": { + "number": { + "min": 0, + "max": 4, + "step": 1 + } + } + } + }, + "fixed_config": { + "power": "{{ [[num_switches]] * 0.20 }}" + } +} +``` + +When the user provides the number `2`, the template will be ``{{ 2 * 0.20 }}`` which will result in `0.40`. + +### Example entity selector + +In the example below we have a profile that asks the user to select a binary sensor. +The profile then calculates the power usage based on the state of the binary sensor. + +```json +{ + "calculation_strategy": "composite", + "fields": { + "some_entity": { + "label": "Some entity", + "description": "Select some entity", + "selector": { + "entity": { + "domain": "binary_sensor" + } + } + } + }, + "composite_config": [ + { + "condition": { + "condition": "state", + "entity_id": "[[some_entity]]", + "state": "on" + }, + "fixed": { + "power": 20 + } + }, + { + "fixed": { + "power": 10 + } + } + ] +} + +``` diff --git a/profile_library/model_schema.json b/profile_library/model_schema.json index c22d541ac..e0cd0bbb2 100644 --- a/profile_library/model_schema.json +++ b/profile_library/model_schema.json @@ -115,6 +115,49 @@ ], "description": "Whether to discover the profile by device or entity" }, + "fields": { + "type": "array", + "items": { + "type": "object", + "properties": { + "description": { + "type": "string", + "description": "Description of the field" + }, + "label": { + "type": "string", + "description": "Label of the field" + }, + "selector": { + "type": "object", + "properties": { + "entity": { + "type": "object", + "properties": { + "domain": { + "type": "string" + }, + "device_class": { + "type": "string" + } + } + }, + "number": { + "type": "object", + "properties": { + "min": { + "type": "number" + }, + "max": { + "type": "number" + } + } + } + } + } + } + } + }, "fixed_config": { "type": "object", "description": "Configuration for fixed calculation strategy", diff --git a/tests/config_flow/test_dynamic_field_builder.py b/tests/config_flow/test_dynamic_field_builder.py new file mode 100644 index 000000000..9e925936b --- /dev/null +++ b/tests/config_flow/test_dynamic_field_builder.py @@ -0,0 +1,83 @@ +from typing import Any + +from homeassistant.core import HomeAssistant +from homeassistant.helpers.selector import EntitySelector, NumberSelector + +from custom_components.powercalc.flow_helper.dynamic_field_builder import build_dynamic_field_schema +from custom_components.powercalc.power_profile.power_profile import PowerProfile + + +def test_build_schema(hass: HomeAssistant) -> None: + profile = create_power_profile( + hass, + { + "test1": { + "label": "Test 1", + "description": "Test 1", + "selector": { + "entity": { + "multiple": True, + "device_class": "power", + }, + }, + }, + "test2": { + "label": "Test 2", + "description": "Test 2", + "selector": { + "number": { + "min": 0, + "max": 60, + "step": 1, + "unit_of_measurement": "minutes", + "mode": "slider", + }, + }, + }, + }, + ) + schema = build_dynamic_field_schema(profile) + assert len(schema.schema) == 2 + assert "test1" in schema.schema + test1 = schema.schema["test1"] + assert isinstance(test1, EntitySelector) + assert test1.config == {"multiple": True, "device_class": ["power"]} + + assert "test2" in schema.schema + test2 = schema.schema["test2"] + assert isinstance(test2, NumberSelector) + assert test2.config == {"min": 0, "max": 60, "step": 1, "unit_of_measurement": "minutes", "mode": "slider"} + + +def test_omit_description(hass: HomeAssistant) -> None: + profile = create_power_profile( + hass, + { + "test1": { + "label": "Test 1", + "selector": { + "entity": { + "multiple": True, + "device_class": "power", + }, + }, + }, + }, + ) + schema = build_dynamic_field_schema(profile) + + schema_keys = list(schema.schema.keys()) + assert schema_keys[schema_keys.index("test1")].description == "Test 1" + + +def create_power_profile(hass: HomeAssistant, fields: dict[str, Any]) -> PowerProfile: + return PowerProfile( + hass, + "test", + "test", + "", + { + "name": "test", + "fields": fields, + }, + ) diff --git a/tests/config_flow/test_virtual_power_library.py b/tests/config_flow/test_virtual_power_library.py index 2f8c45b83..90ecff172 100644 --- a/tests/config_flow/test_virtual_power_library.py +++ b/tests/config_flow/test_virtual_power_library.py @@ -1,9 +1,18 @@ +import logging + +import pytest import voluptuous as vol from homeassistant import data_entry_flow -from homeassistant.const import CONF_ENTITY_ID, STATE_ON +from homeassistant.const import CONF_ENTITY_ID, CONF_NAME, STATE_ON from homeassistant.core import HomeAssistant from homeassistant.helpers.selector import SelectSelector +from custom_components.powercalc import ( + CONF_CREATE_ENERGY_SENSOR, + CONF_CREATE_UTILITY_METERS, + CONF_ENERGY_INTEGRATION_METHOD, + DEFAULT_ENERGY_INTEGRATION_METHOD, +) from custom_components.powercalc.config_flow import ( CONF_CONFIRM_AUTODISCOVERED_MODEL, Step, @@ -13,11 +22,12 @@ CONF_MODE, CONF_MODEL, CONF_SENSOR_TYPE, + CONF_VARIABLES, CalculationStrategy, SensorType, ) from custom_components.test.light import MockLight -from tests.common import create_mock_light_entity +from tests.common import create_mock_light_entity, get_test_config_dir from tests.config_flow.common import ( DEFAULT_UNIQUE_ID, create_mock_entry, @@ -183,3 +193,67 @@ async def test_change_manufacturer_model_from_options_flow(hass: HomeAssistant) assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY assert entry.data[CONF_MANUFACTURER] == "signify" assert entry.data[CONF_MODEL] == "LWB010" + + +async def test_profile_with_custom_fields( + hass: HomeAssistant, + mock_entity_with_model_information: MockEntityWithModel, + caplog: pytest.LogCaptureFixture, +) -> None: + caplog.set_level(logging.ERROR) + + hass.config.config_dir = get_test_config_dir() + mock_entity_with_model_information( + "sensor.test", + "test", + "custom-fields", + unique_id=DEFAULT_UNIQUE_ID, + ) + + result = await select_menu_item(hass, Step.MENU_LIBRARY) + assert result["type"] == data_entry_flow.FlowResultType.FORM + assert result["step_id"] == Step.VIRTUAL_POWER + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + {CONF_ENTITY_ID: "sensor.test"}, + ) + + assert result["type"] == data_entry_flow.FlowResultType.FORM + assert result["step_id"] == Step.LIBRARY + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + {CONF_CONFIRM_AUTODISCOVERED_MODEL: True}, + ) + assert result["type"] == data_entry_flow.FlowResultType.FORM + assert result["step_id"] == Step.LIBRARY_CUSTOM_FIELDS + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + {"some_entity": "sensor.foobar"}, + ) + + assert result["type"] == data_entry_flow.FlowResultType.FORM + assert result["step_id"] == Step.POWER_ADVANCED + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + {}, + ) + + assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY + assert result["data"] == { + CONF_CREATE_ENERGY_SENSOR: True, + CONF_CREATE_UTILITY_METERS: False, + CONF_ENERGY_INTEGRATION_METHOD: DEFAULT_ENERGY_INTEGRATION_METHOD, + CONF_ENTITY_ID: "sensor.test", + CONF_NAME: "test", + CONF_MANUFACTURER: "test", + CONF_MODEL: "custom-fields", + CONF_SENSOR_TYPE: SensorType.VIRTUAL_POWER, + CONF_VARIABLES: { + "some_entity": "sensor.foobar", + }, + } + + assert not caplog.records diff --git a/tests/power_profile/device_types/test_custom_fields.py b/tests/power_profile/device_types/test_custom_fields.py new file mode 100644 index 000000000..8abedf016 --- /dev/null +++ b/tests/power_profile/device_types/test_custom_fields.py @@ -0,0 +1,108 @@ +import logging + +import pytest +from homeassistant.const import CONF_ENTITY_ID, CONF_NAME, STATE_ON +from homeassistant.core import HomeAssistant + +from custom_components.powercalc.const import CONF_CUSTOM_MODEL_DIRECTORY, CONF_MANUFACTURER, CONF_MODEL, CONF_VARIABLES, DUMMY_ENTITY_ID +from custom_components.powercalc.power_profile.error import LibraryError +from custom_components.powercalc.power_profile.library import ProfileLibrary +from tests.common import get_test_config_dir, get_test_profile_dir, run_powercalc_setup + + +async def test_custom_field_variables_from_yaml_config(hass: HomeAssistant, caplog: pytest.LogCaptureFixture) -> None: + """Test custom field variables can be passed from YAML configuration""" + caplog.set_level(logging.ERROR) + hass.config.config_dir = get_test_config_dir() + + hass.states.async_set("sensor.test", STATE_ON) + await hass.async_block_till_done() + + await run_powercalc_setup( + hass, + { + CONF_ENTITY_ID: DUMMY_ENTITY_ID, + CONF_NAME: "Test", + CONF_MANUFACTURER: "test", + CONF_MODEL: "custom-fields", + CONF_VARIABLES: { + "some_entity": "sensor.test", + }, + }, + ) + + assert not caplog.records + + assert hass.states.get("sensor.test_power").state == "20.00" + + +async def test_validation_fails_when_not_all_variables_passed(hass: HomeAssistant, caplog: pytest.LogCaptureFixture) -> None: + """Test error is logged when not all variables are passed, when setting up profile with custom fields""" + caplog.set_level(logging.ERROR) + hass.config.config_dir = get_test_config_dir() + + await run_powercalc_setup( + hass, + { + CONF_ENTITY_ID: DUMMY_ENTITY_ID, + CONF_NAME: "Test", + CONF_MANUFACTURER: "test", + CONF_MODEL: "custom-fields", + CONF_VARIABLES: {}, + }, + ) + + assert "Missing variables for fields: some_entity" in caplog.text + + +@pytest.mark.parametrize( + "custom_field_keys,variables,valid", + [ + (["var1"], {"var1": "sensor.foo"}, True), + ([], {}, True), + (["var1", "var2"], {"var1": "sensor.test"}, False), + (["var1"], {"var3": "sensor.test"}, False), + (["var1"], {"var1": "sensor.test", "var2": "sensor.test"}, False), + ], +) +async def test_validate_variables( + hass: HomeAssistant, + custom_field_keys: list, + variables: dict, + valid: bool, +) -> None: + lib = await ProfileLibrary.factory(hass) + + custom_fields = {key: {"name": "test", "selector": {"number": {}}} for key in custom_field_keys} + json_data = {"fields": custom_fields} + + if not valid: + with pytest.raises(LibraryError): + lib.validate_variables( + json_data, + variables, + ) + return + + lib.validate_variables( + json_data, + variables, + ) + + +async def test_custom_fields_with_template(hass: HomeAssistant) -> None: + """Test custom field variables can be passed from YAML configuration""" + hass.states.async_set("switch.test", STATE_ON) + await hass.async_block_till_done() + await run_powercalc_setup( + hass, + { + CONF_ENTITY_ID: "switch.test", + CONF_CUSTOM_MODEL_DIRECTORY: get_test_profile_dir("custom-fields-template"), + CONF_VARIABLES: { + "num_switches": 4, + }, + }, + ) + + assert hass.states.get("sensor.test_power").state == "0.80" diff --git a/tests/power_profile/loader/test_local.py b/tests/power_profile/loader/test_local.py index 0b0fcb675..d0a9376c2 100644 --- a/tests/power_profile/loader/test_local.py +++ b/tests/power_profile/loader/test_local.py @@ -12,21 +12,21 @@ async def test_broken_lib_by_identical_model_alias(hass: HomeAssistant, caplog: pytest.LogCaptureFixture) -> None: - loader = LocalLoader(hass, get_test_config_dir("powercalc_profiles/double-model")) + loader = LocalLoader(hass, get_test_profile_dir("double-model")) with caplog.at_level(logging.ERROR): await loader.initialize() assert "Double entry manufacturer/model in custom library:" in caplog.text async def test_broken_lib_by_identical_alias_alias(hass: HomeAssistant, caplog: pytest.LogCaptureFixture) -> None: - loader = LocalLoader(hass, get_test_config_dir("powercalc_profiles/double-alias")) + loader = LocalLoader(hass, get_test_profile_dir("double-alias")) with caplog.at_level(logging.ERROR): await loader.initialize() assert "Double entry manufacturer/model in custom library" in caplog.text async def test_broken_lib_by_missing_model_json(hass: HomeAssistant, caplog: pytest.LogCaptureFixture) -> None: - loader = LocalLoader(hass, get_test_config_dir("powercalc_profiles/missing-model-json")) + loader = LocalLoader(hass, get_test_profile_dir("missing-model-json")) with caplog.at_level(logging.ERROR): await loader.initialize() assert "model.json should exist in" in caplog.text diff --git a/tests/power_profile/test_power_profile.py b/tests/power_profile/test_power_profile.py index 8e9758f00..594f0ccb0 100644 --- a/tests/power_profile/test_power_profile.py +++ b/tests/power_profile/test_power_profile.py @@ -1,3 +1,5 @@ +from typing import Any + import pytest from homeassistant.const import STATE_OFF, STATE_ON from homeassistant.core import HomeAssistant, State @@ -309,3 +311,73 @@ async def test_device_type(hass: HomeAssistant) -> None: ) assert power_profile.device_type == DeviceType.SMART_SPEAKER + + +@pytest.mark.parametrize( + "json_data,expected_result", + [ + ( + { + "calculation_strategy": CalculationStrategy.FIXED, + }, + True, + ), + ( + { + "calculation_strategy": CalculationStrategy.LINEAR, + }, + True, + ), + ( + { + "calculation_strategy": CalculationStrategy.COMPOSITE, + "fields": { + "foo": { + "label": "Foo", + "selector": {"entity": {}}, + }, + }, + }, + True, + ), + ( + { + "calculation_strategy": CalculationStrategy.FIXED, + "fixed_config": { + "power": 50, + }, + }, + False, + ), + ( + { + "calculation_strategy": CalculationStrategy.LINEAR, + "linear_config": { + "min_power": 50, + "max_power": 100, + }, + }, + False, + ), + ( + { + "calculation_strategy": CalculationStrategy.MULTI_SWITCH, + "multi_switch_config": { + "power": 0.725, + "power_off": 0.225, + }, + }, + True, + ), + ], +) +async def test_needs_user_configuration(hass: HomeAssistant, json_data: dict[str, Any], expected_result: bool) -> None: + power_profile = PowerProfile( + hass, + manufacturer="test", + model="test", + directory=get_test_profile_dir("media_player"), + json_data=json_data, + ) + + assert await power_profile.needs_user_configuration == expected_result diff --git a/tests/strategy/test_lut.py b/tests/strategy/test_lut.py index 55aa71955..b525536ea 100644 --- a/tests/strategy/test_lut.py +++ b/tests/strategy/test_lut.py @@ -22,7 +22,7 @@ from custom_components.powercalc.strategy.strategy_interface import ( PowerCalculationStrategyInterface, ) -from tests.common import get_test_config_dir, run_powercalc_setup +from tests.common import get_test_profile_dir, run_powercalc_setup from tests.strategy.common import create_source_entity @@ -286,7 +286,7 @@ async def test_fallback_to_non_gzipped_file(hass: HomeAssistant) -> None: hass, "test", "test", - custom_profile_dir=get_test_config_dir("powercalc_profiles/lut-non-gzipped"), + custom_profile_dir=get_test_profile_dir("lut-non-gzipped"), ) await _calculate_and_assert_power( strategy, diff --git a/tests/testing_config/powercalc/profiles/test/custom-fields/model.json b/tests/testing_config/powercalc/profiles/test/custom-fields/model.json new file mode 100644 index 000000000..c6640c5db --- /dev/null +++ b/tests/testing_config/powercalc/profiles/test/custom-fields/model.json @@ -0,0 +1,35 @@ +{ + "name": "My device", + "measure_method": "manual", + "measure_device": "xx", + "device_type": "generic_iot", + "calculation_strategy": "composite", + "fields": { + "some_entity": { + "label": "Some entity", + "description": "Select some entity", + "selector": { + "entity": { + "domain": "sensor" + } + } + } + }, + "composite_config": [ + { + "condition": { + "condition": "state", + "entity_id": "[[some_entity]]", + "state": "on" + }, + "fixed": { + "power": 20 + } + }, + { + "fixed": { + "power": 10 + } + } + ] +} diff --git a/tests/testing_config/powercalc_profiles/custom-fields-template/model.json b/tests/testing_config/powercalc_profiles/custom-fields-template/model.json new file mode 100644 index 000000000..c8ad709ce --- /dev/null +++ b/tests/testing_config/powercalc_profiles/custom-fields-template/model.json @@ -0,0 +1,23 @@ +{ + "name": "My device", + "measure_method": "manual", + "measure_device": "xx", + "device_type": "generic_iot", + "calculation_strategy": "fixed", + "fields": { + "num_switches": { + "label": "Number of switches", + "description": "Enter some number", + "selector": { + "number": { + "min": 0, + "max": 4, + "step": 1 + } + } + } + }, + "fixed_config": { + "power": "{{ [[num_switches]] * 0.20 }}" + } +}