From ebd669302b2ae3f5fc6e0e845f29eddfd7090b21 Mon Sep 17 00:00:00 2001 From: jan iversen Date: Tue, 6 Jun 2023 10:39:53 +0200 Subject: [PATCH] transport_connect -> bool. --- API_changes.rst | 5 +++++ pymodbus/client/serial.py | 2 +- pymodbus/client/tcp.py | 2 +- pymodbus/client/tls.py | 2 +- pymodbus/client/udp.py | 2 +- pymodbus/transport/transport.py | 14 ++++++-------- test/transport/conftest.py | 12 ++++++++++-- test/transport/test_basic.py | 26 ++++++++++++-------------- test/transport/test_comm.py | 18 +++++++++--------- test/transport/test_data.py | 2 +- test/transport/test_reconnect.py | 2 +- 11 files changed, 48 insertions(+), 39 deletions(-) diff --git a/API_changes.rst b/API_changes.rst index 6ff56228c..3b95e708c 100644 --- a/API_changes.rst +++ b/API_changes.rst @@ -2,6 +2,11 @@ PyModbus - API changes. ======================= +------------- +Version 3.4.0 +------------- +- ModbusClient .connect() returns True/False (connected or not) + ------------- Version 3.3.1 ------------- diff --git a/pymodbus/client/serial.py b/pymodbus/client/serial.py index 901b662a0..5f56a7a64 100644 --- a/pymodbus/client/serial.py +++ b/pymodbus/client/serial.py @@ -74,7 +74,7 @@ def connected(self): """Connect internal.""" return self.transport is not None - async def connect(self): + async def connect(self) -> bool: """Connect Async client.""" # if reconnect_delay_current was set to 0 by close(), we need to set it back again # so this instance will work diff --git a/pymodbus/client/tcp.py b/pymodbus/client/tcp.py index 304bf4017..64efaf3f2 100644 --- a/pymodbus/client/tcp.py +++ b/pymodbus/client/tcp.py @@ -58,7 +58,7 @@ def __init__( else: self.setup_tcp(False, host, port) - async def connect(self): + async def connect(self) -> bool: """Initiate connection to start client.""" # if reconnect_delay_current was set to 0 by close(), we need to set it back again diff --git a/pymodbus/client/tls.py b/pymodbus/client/tls.py index f724e3946..ef6a26f6c 100644 --- a/pymodbus/client/tls.py +++ b/pymodbus/client/tls.py @@ -93,7 +93,7 @@ def __init__( False, host, port, sslctx, certfile, keyfile, password, server_hostname ) - async def connect(self): + async def connect(self) -> bool: """Initiate connection to start client.""" # if reconnect_delay_current was set to 0 by close(), we need to set it back again diff --git a/pymodbus/client/udp.py b/pymodbus/client/udp.py index 3b103c73f..5320aea16 100644 --- a/pymodbus/client/udp.py +++ b/pymodbus/client/udp.py @@ -61,7 +61,7 @@ def connected(self): """Return true if connected.""" return self.transport is not None - async def connect(self): + async def connect(self) -> bool: """Start reconnecting asynchronous udp client. :meta private: diff --git a/pymodbus/transport/transport.py b/pymodbus/transport/transport.py index dddf70ca8..aff2ec0a1 100644 --- a/pymodbus/transport/transport.py +++ b/pymodbus/transport/transport.py @@ -6,7 +6,6 @@ import asyncio import ssl import sys -from contextlib import suppress from dataclasses import dataclass from typing import Any, Callable, Coroutine @@ -97,8 +96,6 @@ def __init__( self.transport: asyncio.BaseTransport | asyncio.Server = None self.protocol: asyncio.BaseProtocol = None self.loop: asyncio.AbstractEventLoop = None - with suppress(RuntimeError): - self.loop = asyncio.get_running_loop() self.reconnect_task: asyncio.Task = None self.recv_buffer: bytes = b"" self.call_connect_listen: Callable[[], Coroutine[Any, Any, Any]] = lambda: None @@ -264,7 +261,7 @@ def setup_serial( timeout=self.comm_params.timeout_connect, ) - async def transport_connect(self): + async def transport_connect(self) -> bool: """Handle generic connect and call on to specific transport connect.""" Log.debug("Connecting {}", self.comm_params.comm_name) if not self.loop: @@ -281,7 +278,8 @@ async def transport_connect(self): ) as exc: Log.warning("Failed to connect {}", exc) self.close(reconnect=True) - return self.transport, self.protocol + return False + return bool(self.transport) async def transport_listen(self): """Handle generic listen and call on to specific transport listen.""" @@ -383,15 +381,15 @@ async def reconnect_connect(self): """Handle reconnect as a task.""" try: self.reconnect_delay_current = self.comm_params.reconnect_delay - transport = None - while not transport: + while True: Log.debug( "Wait {} {} ms before reconnecting.", self.comm_params.comm_name, self.reconnect_delay_current * 1000, ) await asyncio.sleep(self.reconnect_delay_current) - transport, _protocol = await self.transport_connect() + if await self.transport_connect(): + break self.reconnect_delay_current = min( 2 * self.reconnect_delay_current, self.comm_params.reconnect_delay_max, diff --git a/test/transport/conftest.py b/test/transport/conftest.py index 4140fad59..08d4b4c90 100644 --- a/test/transport/conftest.py +++ b/test/transport/conftest.py @@ -1,5 +1,7 @@ """Test transport.""" +import asyncio import os +from contextlib import suppress from dataclasses import dataclass from unittest import mock @@ -63,7 +65,7 @@ def prepare_testparams(): @pytest.fixture(name="transport") async def prepare_transport(): """Prepare transport object.""" - return Transport( + transport = Transport( BaseParams.comm_name, BaseParams.reconnect_delay, BaseParams.reconnect_delay_max, @@ -72,12 +74,15 @@ async def prepare_transport(): mock.Mock(name="cb_connection_lost"), mock.Mock(name="cb_handle_data", return_value=0), ) + with suppress(RuntimeError): + transport.loop = asyncio.get_running_loop() + return transport @pytest_asyncio.fixture(name="transport_server") async def prepare_transport_server(): """Prepare transport object.""" - return Transport( + transport = Transport( BaseParams.comm_name, BaseParams.reconnect_delay, BaseParams.reconnect_delay_max, @@ -86,3 +91,6 @@ async def prepare_transport_server(): mock.Mock(name="cb_connection_lost"), mock.Mock(name="cb_handle_data", return_value=0), ) + with suppress(RuntimeError): + transport.loop = asyncio.get_running_loop() + return transport diff --git a/test/transport/test_basic.py b/test/transport/test_basic.py index ad8816eae..8d2cfea1c 100644 --- a/test/transport/test_basic.py +++ b/test/transport/test_basic.py @@ -137,13 +137,11 @@ async def test_handle_listen(self, transport): async def test_reconnect_connect(self, transport): """Test handle_listen().""" transport.comm_params.reconnect_delay = 0.01 - transport.transport_connect = mock.AsyncMock( - side_effect=[(None, None), (117, 118)] - ) + transport.transport_connect = mock.AsyncMock(side_effect=[False, True]) await transport.reconnect_connect() assert ( transport.reconnect_delay_current - == transport.comm_params.reconnect_delay * 4 + == transport.comm_params.reconnect_delay * 2 ) assert not transport.reconnect_task transport.transport_connect = mock.AsyncMock( @@ -183,11 +181,11 @@ async def test_connect(self, params, transport): mocker = mock.AsyncMock() transport.loop.create_unix_connection = mocker mocker.side_effect = FileNotFoundError("testing") - assert await transport.transport_connect() == (None, None) + assert not await transport.transport_connect() mocker.side_effect = None mocker.return_value = (mock.Mock(), mock.Mock()) - assert mocker.return_value == await transport.transport_connect() + assert await transport.transport_connect() transport.close() async def test_listen(self, params, transport): @@ -223,11 +221,11 @@ async def test_connect(self, params, transport): mocker = mock.AsyncMock() transport.loop.create_connection = mocker mocker.side_effect = asyncio.TimeoutError("testing") - assert await transport.transport_connect() == (None, None) + assert not await transport.transport_connect() mocker.side_effect = None mocker.return_value = (mock.Mock(), mock.Mock()) - assert mocker.return_value == await transport.transport_connect() + assert await transport.transport_connect() transport.close() async def test_listen(self, params, transport): @@ -285,11 +283,11 @@ async def test_connect(self, params, transport): mocker = mock.AsyncMock() transport.loop.create_connection = mocker mocker.side_effect = asyncio.TimeoutError("testing") - assert await transport.transport_connect() == (None, None) + assert not await transport.transport_connect() mocker.side_effect = None mocker.return_value = (mock.Mock(), mock.Mock()) - assert mocker.return_value == await transport.transport_connect() + assert await transport.transport_connect() transport.close() async def test_listen(self, params, transport): @@ -334,11 +332,11 @@ async def test_connect(self, params, transport): mocker = mock.AsyncMock() transport.loop.create_datagram_endpoint = mocker mocker.side_effect = asyncio.TimeoutError("testing") - assert await transport.transport_connect() == (None, None) + assert not await transport.transport_connect() mocker.side_effect = None mocker.return_value = (mock.Mock(), mock.Mock()) - assert mocker.return_value == await transport.transport_connect() + assert await transport.transport_connect() transport.close() async def test_listen(self, params, transport): @@ -393,11 +391,11 @@ async def test_connect(self, params, transport): "pymodbus.transport.transport.create_serial_connection", new=mocker ): mocker.side_effect = asyncio.TimeoutError("testing") - assert await transport.transport_connect() == (None, None) + assert not await transport.transport_connect() mocker.side_effect = None mocker.return_value = (mock.Mock(), mock.Mock()) - assert mocker.return_value == await transport.transport_connect() + assert await transport.transport_connect() transport.close() async def test_listen(self, params, transport): diff --git a/test/transport/test_comm.py b/test/transport/test_comm.py index 5ea591916..18d24f22c 100644 --- a/test/transport/test_comm.py +++ b/test/transport/test_comm.py @@ -33,7 +33,7 @@ async def test_connect(self, transport, domain_socket): """Test connect_unix().""" transport.setup_unix(False, domain_socket) start = time.time() - assert await transport.transport_connect() == (None, None) + assert not await transport.transport_connect() delta = time.time() - start assert delta < transport.comm_params.timeout_connect * 1.2 transport.close() @@ -56,7 +56,7 @@ async def test_connected(self, transport, transport_server, domain_socket): await transport_server.transport_listen() transport.setup_unix(False, domain_socket) - assert await transport.transport_connect() != (None, None) + assert await transport.transport_connect() transport.close() transport_server.close() @@ -69,7 +69,7 @@ async def test_connect(self, transport, domain_host): """Test connect_tcp().""" transport.setup_tcp(False, domain_host, BASE_PORT + 1) start = time.time() - assert await transport.transport_connect() == (None, None) + assert not await transport.transport_connect() delta = time.time() - start assert delta < transport.comm_params.timeout_connect * 1.2 transport.close() @@ -92,7 +92,7 @@ async def test_connected(self, transport, transport_server, domain_host): server = await transport_server.transport_listen() assert server transport.setup_tcp(False, domain_host, BASE_PORT + 3) - assert await transport.transport_connect() != (None, None) + assert await transport.transport_connect() transport.close() transport_server.close() server.close() @@ -115,7 +115,7 @@ async def test_connect(self, transport, params, domain_host): "localhost", ) start = time.time() - assert await transport.transport_connect() == (None, None) + assert not await transport.transport_connect() delta = time.time() - start assert delta < transport.comm_params.timeout_connect * 1.2 transport.close() @@ -157,7 +157,7 @@ async def test_connected(self, transport, transport_server, params, domain_host) assert server transport.setup_tcp(False, domain_host, BASE_PORT + 7) - assert await transport.transport_connect() != (None, None) + assert await transport.transport_connect() transport.close() transport_server.close() server.close() @@ -188,7 +188,7 @@ async def test_connected(self, transport, transport_server, domain_host): server = await transport_server.transport_listen() assert server transport.setup_udp(False, domain_host, BASE_PORT + 11) - assert await transport.transport_connect() != (None, None) + assert await transport.transport_connect() transport.close() transport_server.close() server.close() @@ -212,7 +212,7 @@ async def test_connect(self, transport, positive): 2, ) start = time.time() - assert await transport.transport_connect() == (None, None) + assert not await transport.transport_connect() delta = time.time() - start assert delta < transport.comm_params.timeout_connect * 1.2 transport.close() @@ -247,7 +247,7 @@ async def test_connected(self, transport, transport_server): "E", 2, ) - assert await transport.transport_connect() != (None, None) + assert await transport.transport_connect() transport.close() transport_server.close() server.close() diff --git a/test/transport/test_data.py b/test/transport/test_data.py index a48333f26..0d05f4592 100644 --- a/test/transport/test_data.py +++ b/test/transport/test_data.py @@ -15,7 +15,7 @@ async def test_client_send(self, transport, transport_server): assert transport_server.transport transport.setup_tcp(False, "localhost", BASE_PORT + 1) - assert await transport.transport_connect() != (None, None) + assert await transport.transport_connect() await transport.send(b"ABC") await asyncio.sleep(2) assert transport_server.recv_buffer == b"ABC" diff --git a/test/transport/test_reconnect.py b/test/transport/test_reconnect.py index 4a7c51c5d..f0027a3a3 100644 --- a/test/transport/test_reconnect.py +++ b/test/transport/test_reconnect.py @@ -61,6 +61,6 @@ async def test_reconnect_call_ok(self, transport, commparams): transport.connection_lost(RuntimeError("Connection lost")) await asyncio.sleep(transport.reconnect_delay_current * 1.8) assert mocker.call_count == 1 - assert transport.reconnect_delay_current == commparams.reconnect_delay * 2 + assert transport.reconnect_delay_current == commparams.reconnect_delay assert not transport.reconnect_task transport.close()