Skip to content

Commit

Permalink
Create an own class for a cached HTTP client (#314)
Browse files Browse the repository at this point in the history
This avoids memory leaks for DenonAVR class because of cached HTTP responses.
  • Loading branch information
ol-iver authored Nov 10, 2024
1 parent 8368c7b commit c76907f
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 100 deletions.
177 changes: 120 additions & 57 deletions denonavr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@

import attr
import httpx
from defusedxml.ElementTree import fromstring
from defusedxml import DefusedXmlException
from defusedxml.ElementTree import ParseError, fromstring

from .appcommand import AppCommandCmd
from .const import (
Expand Down Expand Up @@ -83,16 +84,86 @@ def telnet_event_map_factory() -> Dict[str, List]:
return dict(event_map)


@attr.s(auto_attribs=True, hash=False, on_setattr=DENON_ATTR_SETATTR)
@attr.s(auto_attribs=True, hash=False)
class HTTPXAsyncClient:
"""Perform cached HTTP calls with httpx.AsyncClient."""

client_getter: Callable[[], httpx.AsyncClient] = attr.ib(
validator=attr.validators.is_callable(),
default=get_default_async_client,
init=False,
)

def __hash__(self) -> int:
"""Hash the class using its ID that caching works."""
return id(self)

@cache_result
@async_handle_receiver_exceptions
async def async_get(
self,
url: str,
timeout: float,
read_timeout: float,
*,
cache_id: Hashable = None,
) -> httpx.Response:
"""Call GET endpoint of Denon AVR receiver asynchronously."""
client = self.client_getter()
try:
res = await client.get(
url, timeout=httpx.Timeout(timeout, read=read_timeout)
)
res.raise_for_status()
finally:
# Close the default AsyncClient but keep custom clients open
if self.is_default_async_client():
await client.aclose()

return res

@cache_result
@async_handle_receiver_exceptions
async def async_post(
self,
url: str,
timeout: float,
read_timeout: float,
*,
content: Optional[bytes] = None,
data: Optional[Dict] = None,
cache_id: Hashable = None,
) -> httpx.Response:
"""Call GET endpoint of Denon AVR receiver asynchronously."""
client = self.client_getter()
try:
res = await client.post(
url,
content=content,
data=data,
timeout=httpx.Timeout(timeout, read=read_timeout),
)
res.raise_for_status()
finally:
# Close the default AsyncClient but keep custom clients open
if self.is_default_async_client():
await client.aclose()

return res

def is_default_async_client(self) -> bool:
"""Check if default httpx.AsyncClient getter is used."""
return self.client_getter is get_default_async_client


@attr.s(auto_attribs=True, on_setattr=DENON_ATTR_SETATTR)
class DenonAVRApi:
"""Perform API calls to Denon AVR REST interface."""

host: str = attr.ib(converter=str, default="localhost")
port: int = attr.ib(converter=int, default=80)
timeout: httpx.Timeout = attr.ib(
validator=attr.validators.instance_of(httpx.Timeout),
default=httpx.Timeout(2.0, read=15.0),
)
timeout: float = attr.ib(converter=float, default=2.0)
read_timeout: float = attr.ib(converter=float, default=15.0)
_appcommand_update_tags: Tuple[AppCommandCmd] = attr.ib(
validator=attr.validators.deep_iterable(
attr.validators.instance_of(AppCommandCmd),
Expand All @@ -107,104 +178,100 @@ class DenonAVRApi:
),
default=attr.Factory(tuple),
)
async_client_getter: Callable[[], httpx.AsyncClient] = attr.ib(
validator=attr.validators.is_callable(),
default=get_default_async_client,
httpx_async_client: HTTPXAsyncClient = attr.ib(
validator=attr.validators.instance_of(HTTPXAsyncClient),
default=attr.Factory(HTTPXAsyncClient),
init=False,
)

def __hash__(self) -> int:
"""
Hash the class in a custom way that caching works.
It should react on changes of host and port.
"""
return hash((self.host, self.port))

@async_handle_receiver_exceptions
async def async_get(
self, request: str, port: Optional[int] = None
self,
request: str,
*,
port: Optional[int] = None,
cache_id: Hashable = None,
) -> httpx.Response:
"""Call GET endpoint of Denon AVR receiver asynchronously."""
# Use default port of the receiver if no different port is specified
port = port if port is not None else self.port

endpoint = f"http://{self.host}:{port}{request}"

client = self.async_client_getter()
try:
res = await client.get(endpoint, timeout=self.timeout)
res.raise_for_status()
finally:
# Close the default AsyncClient but keep custom clients open
if self.is_default_async_client():
await client.aclose()

return res
return await self.httpx_async_client.async_get(
endpoint, self.timeout, self.read_timeout, cache_id=cache_id
)

@async_handle_receiver_exceptions
async def async_post(
self,
request: str,
*,
content: Optional[bytes] = None,
data: Optional[Dict] = None,
port: Optional[int] = None,
cache_id: Hashable = None,
) -> httpx.Response:
"""Call POST endpoint of Denon AVR receiver asynchronously."""
# Use default port of the receiver if no different port is specified
port = port if port is not None else self.port

endpoint = f"http://{self.host}:{port}{request}"

client = self.async_client_getter()
try:
res = await client.post(
endpoint, content=content, data=data, timeout=self.timeout
)
res.raise_for_status()
finally:
# Close the default AsyncClient but keep custom clients open
if self.is_default_async_client():
await client.aclose()

return res
return await self.httpx_async_client.async_post(
endpoint,
self.timeout,
self.read_timeout,
content=content,
data=data,
cache_id=cache_id,
)

@async_handle_receiver_exceptions
async def async_get_command(self, request: str) -> str:
"""Send HTTP GET command to Denon AVR receiver asynchronously."""
# HTTP GET to endpoint
res = await self.async_get(request)
# Return text
return res.text

@cache_result
@async_handle_receiver_exceptions
async def async_get_xml(
self, request: str, cache_id: Hashable = None
self, request: str, *, cache_id: Hashable = None
) -> ET.Element:
"""Return XML data from HTTP GET endpoint asynchronously."""
# HTTP GET to endpoint
res = await self.async_get(request)
res = await self.async_get(request, cache_id=cache_id)
# create ElementTree
xml_root = fromstring(res.text)
try:
xml_root = fromstring(res.text)
except (
ET.ParseError,
DefusedXmlException,
ParseError,
UnicodeDecodeError,
) as err:
raise AvrInvalidResponseError(f"XMLParseError: {err}", request) from err
# Check validity of XML
self.check_xml_validity(request, xml_root)
# Return ElementTree element
return xml_root

@cache_result
@async_handle_receiver_exceptions
async def async_post_appcommand(
self, request: str, cmds: Tuple[AppCommandCmd], cache_id: Hashable = None
self, request: str, cmds: Tuple[AppCommandCmd], *, cache_id: Hashable = None
) -> ET.Element:
"""Return XML from Appcommand(0300) endpoint asynchronously."""
# Prepare XML body for POST call
content = self.prepare_appcommand_body(cmds)
_LOGGER.debug("Content for %s endpoint: %s", request, content)
# HTTP POST to endpoint
res = await self.async_post(request, content=content)
res = await self.async_post(request, content=content, cache_id=cache_id)
# create ElementTree
xml_root = fromstring(res.text)
try:
xml_root = fromstring(res.text)
except (
ET.ParseError,
DefusedXmlException,
ParseError,
UnicodeDecodeError,
) as err:
raise AvrInvalidResponseError(f"XMLParseError: {err}", request) from err
# Check validity of XML
self.check_xml_validity(request, xml_root)
# Add query tags to result
Expand Down Expand Up @@ -350,10 +417,6 @@ def prepare_appcommand_body(cmd_list: Tuple[AppCommandCmd]) -> bytes:

return body_bytes

def is_default_async_client(self) -> bool:
"""Check if default httpx.AsyncClient getter is used."""
return self.async_client_getter is get_default_async_client


class DenonAVRTelnetProtocol(asyncio.Protocol):
"""Protocol for the Denon AVR Telnet interface."""
Expand Down
37 changes: 5 additions & 32 deletions denonavr/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,11 @@
import inspect
import logging
import time
import xml.etree.ElementTree as ET
from functools import wraps
from typing import Callable, TypeVar

import httpx
from asyncstdlib import lru_cache
from defusedxml import DefusedXmlException
from defusedxml.ElementTree import ParseError

from .exceptions import (
AvrForbiddenError,
Expand All @@ -33,12 +30,7 @@


def async_handle_receiver_exceptions(func: Callable[..., AnyT]) -> Callable[..., AnyT]:
"""
Handle exceptions raised when calling a Denon AVR endpoint asynchronously.
The decorated function must either have a string variable as second
argument or as "request" keyword argument.
"""
"""Handle exceptions raised when calling a Denon AVR endpoint asynchronously."""

@wraps(func)
async def wrapper(*args, **kwargs):
Expand All @@ -64,48 +56,29 @@ async def wrapper(*args, **kwargs):
raise AvrInvalidResponseError(
f"RemoteProtocolError: {err}", err.request
) from err
except (
ET.ParseError,
DefusedXmlException,
ParseError,
UnicodeDecodeError,
) as err:
_LOGGER.debug(
"Defusedxml parse error on request %s: %s", (args, kwargs), err
)
raise AvrInvalidResponseError(
f"XMLParseError: {err}", (args, kwargs)
) from err

return wrapper


def cache_result(func: Callable[..., AnyT]) -> Callable[..., AnyT]:
"""
Decorate a function to cache its results with an lru_cache of maxsize 16.
Decorate a function to cache its results with an lru_cache of maxsize 32.
This decorator also sets an "cache_id" keyword argument if it is not set yet.
When an exception occurs it clears lru_cache to prevent memory leaks in
home-assistant when receiver instances are created and deleted right
away in case the device is offline on setup.
"""
if inspect.signature(func).parameters.get("cache_id") is None:
raise AttributeError(
f"Function {func} does not have a 'cache_id' keyword parameter"
)

lru_decorator = lru_cache(maxsize=16)
lru_decorator = lru_cache(maxsize=32)
cached_func = lru_decorator(func)

@wraps(func)
async def wrapper(*args, **kwargs):
if kwargs.get("cache_id") is None:
kwargs["cache_id"] = time.time()
try:
return await cached_func(*args, **kwargs)
except Exception as err:
_LOGGER.debug("Exception raised, clearing cache: %s", err)
cached_func.cache_clear()
raise

return await cached_func(*args, **kwargs)

return wrapper
2 changes: 1 addition & 1 deletion denonavr/denonavr.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def set_async_client_getter(
"""
if not callable(async_client_getter):
raise AvrCommandError("Provided object is not callable")
self._device.api.async_client_getter = async_client_getter
self._device.api.httpx_async_client.client_getter = async_client_getter

async def async_dynamic_eq_off(self) -> None:
"""Turn DynamicEQ off."""
Expand Down
13 changes: 6 additions & 7 deletions denonavr/foundation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing import Dict, List, Optional, Union

import attr
import httpx

from .api import DenonAVRApi, DenonAVRTelnetApi
from .appcommand import AppCommandCmd, AppCommands
Expand Down Expand Up @@ -146,15 +145,15 @@ async def async_setup(self) -> None:
_LOGGER.debug("Starting device setup")
# Reduce read timeout during receiver identification
# deviceinfo endpoint takes very long to return 404
timeout = self.api.timeout
self.api.timeout = httpx.Timeout(self.api.timeout.connect)
read_timeout = self.api.read_timeout
self.api.read_timeout = self.api.timeout
try:
_LOGGER.debug("Identifying receiver")
await self.async_identify_receiver()
_LOGGER.debug("Getting device info")
await self.async_get_device_info()
finally:
self.api.timeout = timeout
self.api.read_timeout = read_timeout
_LOGGER.debug("Identifying update method")
await self.async_identify_update_method()

Expand Down Expand Up @@ -323,7 +322,7 @@ async def async_identify_update_method(self) -> None:
self._set_friendly_name(xml)

async def async_verify_avr_2016_update_method(
self, cache_id: Hashable = None
self, *, cache_id: Hashable = None
) -> None:
"""Verify if avr 2016 update method is working."""
# Nothing to do if Appcommand.xml interface is not supported
Expand Down Expand Up @@ -833,9 +832,9 @@ def set_api_timeout(
) -> float:
"""Change API timeout on timeout changes too."""
# First change _device.api.host then return value
timeout = httpx.Timeout(value, read=max(value, 15.0))
# pylint: disable=protected-access
instance._device.api.timeout = timeout
instance._device.api.timeout = value
instance._device.api.read_timeout = max(value, 15.0)
instance._device.telnet_api.timeout = value
return value

Expand Down
Loading

0 comments on commit c76907f

Please sign in to comment.