Skip to content

Commit

Permalink
Add support for passing multiple addresses to the client (#796)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Dec 12, 2023
1 parent 4668b1f commit de1d084
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 93 deletions.
23 changes: 14 additions & 9 deletions aioesphomeapi/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def __init__(
zeroconf_instance: ZeroconfInstanceType | None = None,
noise_psk: str | None = None,
expected_name: str | None = None,
addresses: list[str] | None = None,
) -> None:
"""Create a client, this object is shared across sessions.
Expand All @@ -235,10 +236,14 @@ def __init__(
:param expected_name: Require the devices name to match the given expected name.
Can be used to prevent accidentally connecting to a different device if
IP passed as address but DHCP reassigned IP.
:param addresses: Optional list of IP addresses to connect to which takes
precedence over the address parameter. This is most commonly used when
the device has dual stack IPv4 and IPv6 addresses and you do not know
which one to connect to.
"""
self._debug_enabled = _LOGGER.isEnabledFor(logging.DEBUG)
self._params = ConnectionParams(
address=str(address),
addresses=addresses if addresses else [str(address)],
port=port,
password=password,
client_info=client_info,
Expand Down Expand Up @@ -274,17 +279,17 @@ def expected_name(self, value: str | None) -> None:

@property
def address(self) -> str:
return self._params.address
return self._params.addresses[0]

def _set_log_name(self) -> None:
"""Set the log name of the device."""
resolved_address: str | None = None
if self._connection and self._connection.resolved_addr_info:
resolved_address = self._connection.resolved_addr_info[0].sockaddr.address
connected_address: str | None = None
if self._connection and self._connection.connected_address:
connected_address = self._connection.connected_address
self.log_name = build_log_name(
self.cached_name,
self.address,
resolved_address,
self._params.addresses,
connected_address,
)
if self._connection:
self._connection.set_log_name(self.log_name)
Expand Down Expand Up @@ -328,8 +333,8 @@ async def start_connection(
self.log_name,
)
await self._execute_connection_coro(self._connection.start_connection())
# If we resolved the address, we should set the log name now
if self._connection.resolved_addr_info:
# If we connected, we should set the log name now
if self._connection.connected_address:
self._set_log_name()

async def finish_connection(
Expand Down
4 changes: 2 additions & 2 deletions aioesphomeapi/connection.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ cdef object _handle_complex_message

@cython.dataclasses.dataclass
cdef class ConnectionParams:
cdef public str address
cdef public list addresses
cdef public object port
cdef public object password
cdef public object client_info
Expand Down Expand Up @@ -108,7 +108,7 @@ cdef class APIConnection:
cdef bint _handshake_complete
cdef bint _debug_enabled
cdef public str received_name
cdef public object resolved_addr_info
cdef public str connected_address

cpdef void send_message(self, object msg)

Expand Down
31 changes: 15 additions & 16 deletions aioesphomeapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@

@dataclass
class ConnectionParams:
address: str
addresses: list[str]
port: int
password: str | None
client_info: str
Expand Down Expand Up @@ -207,7 +207,7 @@ class APIConnection:
"_handshake_complete",
"_debug_enabled",
"received_name",
"resolved_addr_info",
"connected_address",
)

def __init__(
Expand All @@ -230,7 +230,7 @@ def __init__(
# Message handlers currently subscribed to incoming messages
self._message_handlers: dict[Any, set[Callable[[message.Message], None]]] = {}
# The friendly name to show for this connection in the logs
self.log_name = log_name or params.address
self.log_name = log_name or ",".join(params.addresses)

# futures currently subscribed to exceptions in the read task
self._read_exception_futures: set[asyncio.Future[None]] = set()
Expand All @@ -251,7 +251,7 @@ def __init__(
self._handshake_complete = False
self._debug_enabled = debug_enabled
self.received_name: str = ""
self.resolved_addr_info: list[hr.AddrInfo] = []
self.connected_address: str | None = None

def set_log_name(self, name: str) -> None:
"""Set the friendly log name for this connection."""
Expand Down Expand Up @@ -325,7 +325,7 @@ async def _connect_resolve_host(self) -> list[hr.AddrInfo]:
try:
async with asyncio_timeout(RESOLVE_TIMEOUT):
return await hr.async_resolve_host(
self._params.address,
self._params.addresses,
self._params.port,
self._params.zeroconf_manager,
)
Expand All @@ -338,19 +338,17 @@ async def _connect_socket_connect(self, addrs: list[hr.AddrInfo]) -> None:
"""Step 2 in connect process: connect the socket."""
if self._debug_enabled:
_LOGGER.debug(
"%s: Connecting to %s:%s (%s)",
"%s: Connecting to %s",
self.log_name,
self._params.address,
self._params.port,
addrs,
", ".join(str(addr.sockaddr) for addr in addrs),
)

addr_infos: list[aiohappyeyeballs.AddrInfoType] = [
(
addr.family,
addr.type,
addr.proto,
self._params.address,
"",
astuple(addr.sockaddr),
)
for addr in addrs
Expand All @@ -361,9 +359,11 @@ async def _connect_socket_connect(self, addrs: list[hr.AddrInfo]) -> None:
while addr_infos:
try:
async with asyncio_timeout(TCP_CONNECT_TIMEOUT):
# Devices are likely on the local network so we
# only use a 100ms happy eyeballs delay
sock = await aiohappyeyeballs.start_connection(
addr_infos,
happy_eyeballs_delay=0.25,
happy_eyeballs_delay=0.1,
interleave=interleave,
loop=self._loop,
)
Expand All @@ -387,14 +387,14 @@ async def _connect_socket_connect(self, addrs: list[hr.AddrInfo]) -> None:
# Try to reduce the pressure on esphome device as it measures
# ram in bytes and we measure ram in megabytes.
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, BUFFER_SIZE)
self.connected_address = sock.getpeername()[0]

if self._debug_enabled:
_LOGGER.debug(
"%s: Opened socket to %s:%s (%s)",
"%s: Opened socket to %s:%s",
self.log_name,
self._params.address,
self.connected_address,
self._params.port,
addrs,
)

async def _connect_init_frame_helper(self) -> None:
Expand Down Expand Up @@ -567,8 +567,7 @@ def _async_pong_not_received(self) -> None:

async def _do_connect(self) -> None:
"""Do the actual connect process."""
self.resolved_addr_info = await self._connect_resolve_host()
await self._connect_socket_connect(self.resolved_addr_info)
await self._connect_socket_connect(await self._connect_resolve_host())

async def start_connection(self) -> None:
"""Start the connection process.
Expand Down
48 changes: 29 additions & 19 deletions aioesphomeapi/host_resolver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import asyncio
import contextlib
import logging
import socket
from dataclasses import dataclass
Expand Down Expand Up @@ -181,35 +180,46 @@ def _async_ip_address_to_addrs(


async def async_resolve_host(
host: str,
hosts: list[str],
port: int,
zeroconf_manager: ZeroconfManager | None = None,
) -> list[AddrInfo]:
addrs: list[AddrInfo] = []

zc_error = None
if host_is_name_part(host) or address_is_local(host):
name = host.partition(".")[0]
try:
addrs.extend(
await _async_resolve_host_zeroconf(
name, port, zeroconf_manager=zeroconf_manager
zc_error: Exception | None = None

for host in hosts:
host_addrs: list[AddrInfo] = []
host_is_local_name = host_is_name_part(host) or address_is_local(host)

if host_is_local_name:
name = host.partition(".")[0]
try:
host_addrs.extend(
await _async_resolve_host_zeroconf(
name, port, zeroconf_manager=zeroconf_manager
)
)
)
except ResolveAPIError as err:
zc_error = err
except ResolveAPIError as err:
zc_error = err

else:
with contextlib.suppress(ValueError):
addrs.extend(_async_ip_address_to_addrs(ip_address(host), port))
if not host_is_local_name:
try:
host_addrs.extend(_async_ip_address_to_addrs(ip_address(host), port))
except ValueError:
# Not an IP address
pass

if not addrs:
addrs.extend(await _async_resolve_host_getaddrinfo(host, port))
if not host_addrs:
host_addrs.extend(await _async_resolve_host_getaddrinfo(host, port))

addrs.extend(host_addrs)

if not addrs:
if zc_error:
# Only show ZC error if getaddrinfo also didn't work
raise zc_error
raise ResolveAPIError(f"Could not resolve host {host} - got no results from OS")
raise ResolveAPIError(
f"Could not resolve host {hosts} - got no results from OS"
)

return addrs
15 changes: 11 additions & 4 deletions aioesphomeapi/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,18 @@ def address_is_local(address: str) -> bool:
return address.removesuffix(".").endswith(".local")


def build_log_name(name: str | None, address: str, resolved_address: str | None) -> str:
def build_log_name(
name: str | None, addresses: list[str], connected_address: str | None
) -> str:
"""Return a log name for a connection."""
if not name and address_is_local(address) or host_is_name_part(address):
name = address.partition(".")[0]
preferred_address = resolved_address or address
preferred_address = connected_address
for address in addresses:
if not name and address_is_local(address) or host_is_name_part(address):
name = address.partition(".")[0]
elif not preferred_address:
preferred_address = address
if not preferred_address:
return name or addresses[0]
if (
name
and name != preferred_address
Expand Down
21 changes: 8 additions & 13 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import replace
from functools import partial
from typing import Callable
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, create_autospec, patch

import pytest
import pytest_asyncio
Expand Down Expand Up @@ -50,12 +50,6 @@ def resolve_host():
yield func


@pytest.fixture
def socket_socket():
with patch("socket.socket") as func:
yield func


@pytest.fixture
def patchable_api_client() -> APIClient:
class PatchableAPIClient(APIClient):
Expand All @@ -71,7 +65,7 @@ class PatchableAPIClient(APIClient):

def get_mock_connection_params() -> ConnectionParams:
return ConnectionParams(
address="fake.address",
addresses=["fake.address"],
port=6052,
password=None,
client_info="Tests client",
Expand Down Expand Up @@ -119,7 +113,11 @@ def conn_with_expected_name(connection_params: ConnectionParams) -> APIConnectio
@pytest.fixture()
def aiohappyeyeballs_start_connection():
with patch("aioesphomeapi.connection.aiohappyeyeballs.start_connection") as func:
func.return_value = MagicMock(type=socket.SOCK_STREAM)
mock_socket = create_autospec(socket.socket, spec_set=True, instance=True)
mock_socket.type = socket.SOCK_STREAM
mock_socket.fileno.return_value = 1
mock_socket.getpeername.return_value = ("10.0.0.512", 323)
func.return_value = mock_socket
yield func


Expand All @@ -139,7 +137,6 @@ def _create_mock_transport_protocol(
async def plaintext_connect_task_no_login(
conn: APIConnection,
resolve_host,
socket_socket,
event_loop,
aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
Expand All @@ -161,7 +158,6 @@ async def plaintext_connect_task_no_login(
async def plaintext_connect_task_no_login_with_expected_name(
conn_with_expected_name: APIConnection,
resolve_host,
socket_socket,
event_loop,
aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
Expand All @@ -184,7 +180,6 @@ async def plaintext_connect_task_no_login_with_expected_name(
async def plaintext_connect_task_with_login(
conn_with_password: APIConnection,
resolve_host,
socket_socket,
event_loop,
aiohappyeyeballs_start_connection,
) -> tuple[APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task]:
Expand All @@ -203,7 +198,7 @@ async def plaintext_connect_task_with_login(

@pytest_asyncio.fixture(name="api_client")
async def api_client(
resolve_host, socket_socket, event_loop, aiohappyeyeballs_start_connection
resolve_host, event_loop, aiohappyeyeballs_start_connection
) -> tuple[APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper]:
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
Expand Down
7 changes: 5 additions & 2 deletions tests/test__frame_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ def mock_write_frame(self, frame: bytes) -> None:
),
],
)
def test_plaintext_frame_helper(
@pytest.mark.asyncio
async def test_plaintext_frame_helper(
in_bytes: bytes, pkt_data: bytes, pkt_type: int
) -> None:
for _ in range(3):
Expand Down Expand Up @@ -592,7 +593,9 @@ def _writer(data: bytes):


@pytest.mark.asyncio
async def test_init_plaintext_with_wrong_preamble(conn: APIConnection):
async def test_init_plaintext_with_wrong_preamble(
conn: APIConnection, aiohappyeyeballs_start_connection
):
loop = asyncio.get_event_loop()
protocol = get_mock_protocol(conn)
with patch.object(loop, "create_connection") as create_connection:
Expand Down
Loading

0 comments on commit de1d084

Please sign in to comment.