Skip to content

Commit

Permalink
Merge pull request #2852 from bramstroker/feat/multi-manufacturer-ali…
Browse files Browse the repository at this point in the history
…ases

Support same alias for multiple manufacturers in library loading code
  • Loading branch information
bramstroker authored Dec 27, 2024
2 parents 210d0f3 + 166d44a commit b5450e1
Show file tree
Hide file tree
Showing 14 changed files with 231 additions and 126 deletions.
1 change: 0 additions & 1 deletion custom_components/powercalc/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
DATA_ENTITIES = "entities"
DATA_GROUP_ENTITIES = "group_entities"
DATA_USED_UNIQUE_IDS = "used_unique_ids"
DATA_PROFILE_LIBRARY = "profile_library"
DATA_STANDBY_POWER_SENSORS = "standby_power_sensors"

ENTRY_DATA_ENERGY_ENTITY = "_energy_entity"
Expand Down
15 changes: 3 additions & 12 deletions custom_components/powercalc/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,23 +117,15 @@ async def discover_entity(
) -> list[PowerProfile] | None:
"""Discover a single entity in Powercalc library and start the discovery flow if supported."""

library = await self._get_library()
if source_entity.entity_entry is None: # pragma: no cover
return None

if self.is_wled_light(model_info, source_entity.entity_entry):
await self.init_wled_flow(model_info, source_entity)
return None

manufacturer = await library.find_manufacturer(model_info)
if not manufacturer:
_LOGGER.debug(
"%s: Manufacturer not found in library, skipping discovery",
source_entity.entity_entry.entity_id,
)
return None

models = await library.find_models(manufacturer, model_info)
library = await self._get_library()
models = await library.find_models(model_info)
if not models:
_LOGGER.debug(
"%s: Model not found in library, skipping discovery",
Expand All @@ -142,8 +134,7 @@ async def discover_entity(
return None

power_profiles = []
for model in models:
model_info = ModelInfo(manufacturer, model)
for model_info in models:
profile = await get_power_profile(self.hass, {}, model_info=model_info)
if not profile: # pragma: no cover
continue
Expand Down
46 changes: 46 additions & 0 deletions custom_components/powercalc/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
import logging
import os.path
import uuid
from collections.abc import Callable, Coroutine
from decimal import Decimal
from functools import wraps
from typing import Any, TypeVar

from homeassistant.const import CONF_UNIQUE_ID
from homeassistant.helpers.template import Template
Expand Down Expand Up @@ -69,3 +72,46 @@ def get_or_create_unique_id(
return f"pc_{source_unique_id}"

return str(uuid.uuid4())


P = TypeVar("P") # Used for positional and keyword argument types
R = TypeVar("R") # Used for return type


def make_hashable(arg: Any) -> Any: # noqa: ANN401
"""Convert unhashable arguments to hashable equivalents."""
if isinstance(arg, set):
return frozenset(arg)
if isinstance(arg, list):
return tuple(arg)
if isinstance(arg, dict):
return frozenset((key, make_hashable(value)) for key, value in arg.items())
return arg


def async_cache(func: Callable[..., Coroutine[Any, Any, R]]) -> Callable[..., Coroutine[Any, Any, R]]:
"""
A decorator to cache results of an async function based on its arguments.
Args:
func: The asynchronous function to decorate.
Returns:
A decorated asynchronous function with caching.
"""
cache: dict[tuple[tuple[Any, ...], frozenset], R] = {}

@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> R: # noqa: ANN401
# Make arguments hashable
hashable_args = tuple(make_hashable(arg) for arg in args)
hashable_kwargs = frozenset((key, make_hashable(value)) for key, value in kwargs.items())
cache_key = (hashable_args, hashable_kwargs)

if cache_key in cache:
return cache[cache_key]
result = await func(*args, **kwargs)
cache[cache_key] = result
return result

return wrapper
76 changes: 41 additions & 35 deletions custom_components/powercalc/power_profile/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from typing import NamedTuple

from homeassistant.core import HomeAssistant
from homeassistant.helpers.singleton import singleton

from custom_components.powercalc.const import CONF_DISABLE_LIBRARY_DOWNLOAD, DATA_PROFILE_LIBRARY, DOMAIN, DOMAIN_CONFIG
from custom_components.powercalc.const import CONF_DISABLE_LIBRARY_DOWNLOAD, DOMAIN, DOMAIN_CONFIG

from .error import LibraryError
from .loader.composite import CompositeLoader
Expand All @@ -34,19 +35,14 @@ async def initialize(self) -> None:
await self._loader.initialize()

@staticmethod
@singleton("powercalc_library")
async def factory(hass: HomeAssistant) -> ProfileLibrary:
"""Creates and loads the profile library
Makes sure it is only loaded once and instance is saved in hass data registry.
"""
if DOMAIN not in hass.data:
hass.data[DOMAIN] = {}

if DATA_PROFILE_LIBRARY in hass.data[DOMAIN]:
return hass.data[DOMAIN][DATA_PROFILE_LIBRARY] # type: ignore

Creates and loads the profile library.
Make sure we have a single instance throughout the application.
"""
library = ProfileLibrary(hass, ProfileLibrary.create_loader(hass))
await library.initialize()
hass.data[DOMAIN][DATA_PROFILE_LIBRARY] = library
return library

@staticmethod
Expand All @@ -61,7 +57,8 @@ def create_loader(hass: HomeAssistant) -> Loader:
if os.path.exists(data_dir)
]

global_config = hass.data[DOMAIN].get(DOMAIN_CONFIG, {})
domain_config = hass.data.get(DOMAIN, {})
global_config = domain_config.get(DOMAIN_CONFIG, {})
disable_library_download: bool = bool(global_config.get(CONF_DISABLE_LIBRARY_DOWNLOAD, False))
if not disable_library_download:
loaders.append(RemoteLoader(hass))
Expand All @@ -78,17 +75,22 @@ async def get_manufacturer_listing(self, entity_domain: str | None = None) -> li
async def get_model_listing(self, manufacturer: str, entity_domain: str | None = None) -> list[str]:
"""Get listing of available models for a given manufacturer."""

resolved_manufacturer = await self._loader.find_manufacturer(manufacturer)
if not resolved_manufacturer:
resolved_manufacturers = await self._loader.find_manufacturers(manufacturer)
if not resolved_manufacturers:
return []
device_types = get_device_types_from_domain(entity_domain) if entity_domain else None
cache_key = f"{resolved_manufacturer}/{device_types}"
cached_models = self._manufacturer_models.get(cache_key)
if cached_models:
return cached_models
models = await self._loader.get_model_listing(resolved_manufacturer, device_types)
self._manufacturer_models[cache_key] = sorted(models)
return self._manufacturer_models[cache_key]
all_models: list[str] = []
for manufacturer in resolved_manufacturers:
cache_key = f"{manufacturer}/{device_types}"
cached_models = self._manufacturer_models.get(cache_key)
if cached_models:
all_models.extend(cached_models)
continue
models = await self._loader.get_model_listing(manufacturer, device_types)
self._manufacturer_models[cache_key] = sorted(models)
all_models.extend(models)

return all_models

async def get_profile(
self,
Expand Down Expand Up @@ -116,30 +118,24 @@ async def create_power_profile(
) -> PowerProfile:
"""Create a power profile object from the model JSON data."""

manufacturer = model_info.manufacturer
model = model_info.model
if not custom_directory:
manufacturer = await self.find_manufacturer(model_info) # type: ignore
if manufacturer is None:
raise LibraryError(f"Manufacturer {model_info.manufacturer} not found")

models = await self.find_models(manufacturer, model_info)
models = await self.find_models(model_info)
if not models:
raise LibraryError(f"Model {manufacturer} {model} not found")
model = next(iter(models))
raise LibraryError(f"Model {model_info.manufacturer} {model_info.model} not found")
model_info = next(iter(models))

json_data, directory = await self._load_model_data(manufacturer, model, custom_directory)
json_data, directory = await self._load_model_data(model_info.manufacturer, model_info.model, custom_directory)
if linked_profile := json_data.get("linked_lut"):
linked_manufacturer, linked_model = linked_profile.split("/")
_, directory = await self._load_model_data(linked_manufacturer, linked_model, custom_directory)

return await self._create_power_profile_instance(manufacturer, model, directory, json_data)
return await self._create_power_profile_instance(model_info.manufacturer, model_info.model, directory, json_data)

async def find_manufacturer(self, model_info: ModelInfo) -> str | None:
async def find_manufacturers(self, manufacturer: str) -> set[str]:
"""Resolve the manufacturer, either from the model info or by loading it."""
return await self._loader.find_manufacturer(model_info.manufacturer)
return await self._loader.find_manufacturers(manufacturer)

async def find_models(self, manufacturer: str, model_info: ModelInfo) -> set[str]:
async def find_models(self, model_info: ModelInfo) -> set[ModelInfo]:
"""Resolve the model identifier, searching for it if no custom directory is provided."""
search: set[str] = set()
for model_identifier in (model_info.model_id, model_info.model):
Expand All @@ -155,7 +151,17 @@ async def find_models(self, manufacturer: str, model_info: ModelInfo) -> set[str
if "/" in model_identifier:
search.update(model_identifier.split("/"))

return set(await self._loader.find_model(manufacturer, search))
manufacturers = await self._loader.find_manufacturers(model_info.manufacturer)
found_models: set[ModelInfo] = set()
if not manufacturers:
return found_models

for manufacturer in manufacturers:
models = await self._loader.find_model(manufacturer, search)
if models:
found_models.update(ModelInfo(manufacturer, model) for model in models)

return found_models

async def _load_model_data(self, manufacturer: str, model: str, custom_directory: str | None) -> tuple[dict, str]:
"""Load the model data from the appropriate directory."""
Expand Down
17 changes: 9 additions & 8 deletions custom_components/powercalc/power_profile/loader/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@ async def get_manufacturer_listing(self, device_types: set[DeviceType] | None) -

return {manufacturer for loader in self.loaders for manufacturer in await loader.get_manufacturer_listing(device_types)}

async def find_manufacturer(self, search: str) -> str | None:
async def find_manufacturers(self, search: str) -> set[str]:
"""Check if a manufacturer is available. Also must check aliases."""

search = search.lower()
found_manufacturers = set()
for loader in self.loaders:
manufacturer = await loader.find_manufacturer(search)
if manufacturer:
return manufacturer
manufacturers = await loader.find_manufacturers(search)
if manufacturers:
found_manufacturers.update(manufacturers)

return None
return found_manufacturers

async def get_model_listing(self, manufacturer: str, device_types: set[DeviceType] | None) -> set[str]:
"""Get listing of available models for a given manufacturer."""
Expand All @@ -42,11 +43,11 @@ async def load_model(self, manufacturer: str, model: str) -> tuple[dict, str] |

return None

async def find_model(self, manufacturer: str, search: set[str]) -> list[str]:
async def find_model(self, manufacturer: str, search: set[str]) -> set[str]:
"""Find the model in the library."""

models = []
models = set()
for loader in self.loaders:
models.extend(await loader.find_model(manufacturer, search))
models.update(await loader.find_model(manufacturer, search))

return models
12 changes: 6 additions & 6 deletions custom_components/powercalc/power_profile/loader/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ async def get_manufacturer_listing(self, device_types: set[DeviceType] | None) -

return manufacturers

async def find_manufacturer(self, search: str) -> str | None:
async def find_manufacturers(self, search: str) -> set[str]:
"""Check if a manufacturer is available."""

_search = search.lower()
manufacturer_list = self._manufacturer_model_listing.keys()
if _search in manufacturer_list:
return _search
return {_search}

return None
return set()

async def get_model_listing(self, manufacturer: str, device_types: set[DeviceType] | None) -> set[str]:
"""Get listing of available models for a given manufacturer.
Expand Down Expand Up @@ -104,19 +104,19 @@ async def load_model(self, manufacturer: str, model: str) -> tuple[dict, str] |
model_json = lib_model.json_data
return model_json, model_path

async def find_model(self, manufacturer: str, search: set[str]) -> list[str]:
async def find_model(self, manufacturer: str, search: set[str]) -> set[str]:
"""Find a model for a given manufacturer. Also must check aliases."""
_manufacturer = manufacturer.lower()

models = self._manufacturer_model_listing.get(_manufacturer)
if not models:
_LOGGER.info("Manufacturer does not exist in custom library: %s", _manufacturer)
return []
return set()

search_lower = {phrase.lower() for phrase in search}

profile = next((models[model] for model in models if model.lower() in search_lower), None)
return [profile.model] if profile else []
return {profile.model} if profile else set()

def _load_custom_library(self) -> None:
"""Loading custom models and aliases from file system.
Expand Down
4 changes: 2 additions & 2 deletions custom_components/powercalc/power_profile/loader/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ async def initialize(self) -> None:
async def get_manufacturer_listing(self, device_types: set[DeviceType] | None) -> set[str]:
"""Get listing of possible manufacturers."""

async def find_manufacturer(self, search: str) -> str | None:
async def find_manufacturers(self, search: str) -> set[str]:
"""Check if a manufacturer is available. Also must check aliases."""

async def get_model_listing(self, manufacturer: str, device_types: set[DeviceType] | None) -> set[str]:
Expand All @@ -19,5 +19,5 @@ async def get_model_listing(self, manufacturer: str, device_types: set[DeviceTyp
async def load_model(self, manufacturer: str, model: str) -> tuple[dict, str] | None:
"""Load and optionally download a model profile."""

async def find_model(self, manufacturer: str, search: set[str]) -> list[str]:
async def find_model(self, manufacturer: str, search: set[str]) -> set[str]:
"""Check if a model is available. Also must check aliases."""
Loading

0 comments on commit b5450e1

Please sign in to comment.