Skip to content

Commit

Permalink
transport_connect -> bool.
Browse files Browse the repository at this point in the history
  • Loading branch information
janiversen committed Jun 6, 2023
1 parent 27cc16d commit ebd6693
Show file tree
Hide file tree
Showing 11 changed files with 48 additions and 39 deletions.
5 changes: 5 additions & 0 deletions API_changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
PyModbus - API changes.
=======================

-------------
Version 3.4.0
-------------
- Modbus<x>Client .connect() returns True/False (connected or not)

-------------
Version 3.3.1
-------------
Expand Down
2 changes: 1 addition & 1 deletion pymodbus/client/serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def connected(self):
"""Connect internal."""
return self.transport is not None

async def connect(self):
async def connect(self) -> bool:
"""Connect Async client."""
# if reconnect_delay_current was set to 0 by close(), we need to set it back again
# so this instance will work
Expand Down
2 changes: 1 addition & 1 deletion pymodbus/client/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(
else:
self.setup_tcp(False, host, port)

async def connect(self):
async def connect(self) -> bool:
"""Initiate connection to start client."""

# if reconnect_delay_current was set to 0 by close(), we need to set it back again
Expand Down
2 changes: 1 addition & 1 deletion pymodbus/client/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(
False, host, port, sslctx, certfile, keyfile, password, server_hostname
)

async def connect(self):
async def connect(self) -> bool:
"""Initiate connection to start client."""

# if reconnect_delay_current was set to 0 by close(), we need to set it back again
Expand Down
2 changes: 1 addition & 1 deletion pymodbus/client/udp.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def connected(self):
"""Return true if connected."""
return self.transport is not None

async def connect(self):
async def connect(self) -> bool:
"""Start reconnecting asynchronous udp client.
:meta private:
Expand Down
14 changes: 6 additions & 8 deletions pymodbus/transport/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import asyncio
import ssl
import sys
from contextlib import suppress
from dataclasses import dataclass
from typing import Any, Callable, Coroutine

Expand Down Expand Up @@ -97,8 +96,6 @@ def __init__(
self.transport: asyncio.BaseTransport | asyncio.Server = None
self.protocol: asyncio.BaseProtocol = None
self.loop: asyncio.AbstractEventLoop = None
with suppress(RuntimeError):
self.loop = asyncio.get_running_loop()
self.reconnect_task: asyncio.Task = None
self.recv_buffer: bytes = b""
self.call_connect_listen: Callable[[], Coroutine[Any, Any, Any]] = lambda: None
Expand Down Expand Up @@ -264,7 +261,7 @@ def setup_serial(
timeout=self.comm_params.timeout_connect,
)

async def transport_connect(self):
async def transport_connect(self) -> bool:
"""Handle generic connect and call on to specific transport connect."""
Log.debug("Connecting {}", self.comm_params.comm_name)
if not self.loop:
Expand All @@ -281,7 +278,8 @@ async def transport_connect(self):
) as exc:
Log.warning("Failed to connect {}", exc)
self.close(reconnect=True)
return self.transport, self.protocol
return False
return bool(self.transport)

async def transport_listen(self):
"""Handle generic listen and call on to specific transport listen."""
Expand Down Expand Up @@ -383,15 +381,15 @@ async def reconnect_connect(self):
"""Handle reconnect as a task."""
try:
self.reconnect_delay_current = self.comm_params.reconnect_delay
transport = None
while not transport:
while True:
Log.debug(
"Wait {} {} ms before reconnecting.",
self.comm_params.comm_name,
self.reconnect_delay_current * 1000,
)
await asyncio.sleep(self.reconnect_delay_current)
transport, _protocol = await self.transport_connect()
if await self.transport_connect():
break
self.reconnect_delay_current = min(
2 * self.reconnect_delay_current,
self.comm_params.reconnect_delay_max,
Expand Down
12 changes: 10 additions & 2 deletions test/transport/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Test transport."""
import asyncio
import os
from contextlib import suppress
from dataclasses import dataclass
from unittest import mock

Expand Down Expand Up @@ -63,7 +65,7 @@ def prepare_testparams():
@pytest.fixture(name="transport")
async def prepare_transport():
"""Prepare transport object."""
return Transport(
transport = Transport(
BaseParams.comm_name,
BaseParams.reconnect_delay,
BaseParams.reconnect_delay_max,
Expand All @@ -72,12 +74,15 @@ async def prepare_transport():
mock.Mock(name="cb_connection_lost"),
mock.Mock(name="cb_handle_data", return_value=0),
)
with suppress(RuntimeError):
transport.loop = asyncio.get_running_loop()
return transport


@pytest_asyncio.fixture(name="transport_server")
async def prepare_transport_server():
"""Prepare transport object."""
return Transport(
transport = Transport(
BaseParams.comm_name,
BaseParams.reconnect_delay,
BaseParams.reconnect_delay_max,
Expand All @@ -86,3 +91,6 @@ async def prepare_transport_server():
mock.Mock(name="cb_connection_lost"),
mock.Mock(name="cb_handle_data", return_value=0),
)
with suppress(RuntimeError):
transport.loop = asyncio.get_running_loop()
return transport
26 changes: 12 additions & 14 deletions test/transport/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,11 @@ async def test_handle_listen(self, transport):
async def test_reconnect_connect(self, transport):
"""Test handle_listen()."""
transport.comm_params.reconnect_delay = 0.01
transport.transport_connect = mock.AsyncMock(
side_effect=[(None, None), (117, 118)]
)
transport.transport_connect = mock.AsyncMock(side_effect=[False, True])
await transport.reconnect_connect()
assert (
transport.reconnect_delay_current
== transport.comm_params.reconnect_delay * 4
== transport.comm_params.reconnect_delay * 2
)
assert not transport.reconnect_task
transport.transport_connect = mock.AsyncMock(
Expand Down Expand Up @@ -183,11 +181,11 @@ async def test_connect(self, params, transport):
mocker = mock.AsyncMock()
transport.loop.create_unix_connection = mocker
mocker.side_effect = FileNotFoundError("testing")
assert await transport.transport_connect() == (None, None)
assert not await transport.transport_connect()
mocker.side_effect = None

mocker.return_value = (mock.Mock(), mock.Mock())
assert mocker.return_value == await transport.transport_connect()
assert await transport.transport_connect()
transport.close()

async def test_listen(self, params, transport):
Expand Down Expand Up @@ -223,11 +221,11 @@ async def test_connect(self, params, transport):
mocker = mock.AsyncMock()
transport.loop.create_connection = mocker
mocker.side_effect = asyncio.TimeoutError("testing")
assert await transport.transport_connect() == (None, None)
assert not await transport.transport_connect()
mocker.side_effect = None

mocker.return_value = (mock.Mock(), mock.Mock())
assert mocker.return_value == await transport.transport_connect()
assert await transport.transport_connect()
transport.close()

async def test_listen(self, params, transport):
Expand Down Expand Up @@ -285,11 +283,11 @@ async def test_connect(self, params, transport):
mocker = mock.AsyncMock()
transport.loop.create_connection = mocker
mocker.side_effect = asyncio.TimeoutError("testing")
assert await transport.transport_connect() == (None, None)
assert not await transport.transport_connect()
mocker.side_effect = None

mocker.return_value = (mock.Mock(), mock.Mock())
assert mocker.return_value == await transport.transport_connect()
assert await transport.transport_connect()
transport.close()

async def test_listen(self, params, transport):
Expand Down Expand Up @@ -334,11 +332,11 @@ async def test_connect(self, params, transport):
mocker = mock.AsyncMock()
transport.loop.create_datagram_endpoint = mocker
mocker.side_effect = asyncio.TimeoutError("testing")
assert await transport.transport_connect() == (None, None)
assert not await transport.transport_connect()
mocker.side_effect = None

mocker.return_value = (mock.Mock(), mock.Mock())
assert mocker.return_value == await transport.transport_connect()
assert await transport.transport_connect()
transport.close()

async def test_listen(self, params, transport):
Expand Down Expand Up @@ -393,11 +391,11 @@ async def test_connect(self, params, transport):
"pymodbus.transport.transport.create_serial_connection", new=mocker
):
mocker.side_effect = asyncio.TimeoutError("testing")
assert await transport.transport_connect() == (None, None)
assert not await transport.transport_connect()
mocker.side_effect = None

mocker.return_value = (mock.Mock(), mock.Mock())
assert mocker.return_value == await transport.transport_connect()
assert await transport.transport_connect()
transport.close()

async def test_listen(self, params, transport):
Expand Down
18 changes: 9 additions & 9 deletions test/transport/test_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ async def test_connect(self, transport, domain_socket):
"""Test connect_unix()."""
transport.setup_unix(False, domain_socket)
start = time.time()
assert await transport.transport_connect() == (None, None)
assert not await transport.transport_connect()
delta = time.time() - start
assert delta < transport.comm_params.timeout_connect * 1.2
transport.close()
Expand All @@ -56,7 +56,7 @@ async def test_connected(self, transport, transport_server, domain_socket):
await transport_server.transport_listen()

transport.setup_unix(False, domain_socket)
assert await transport.transport_connect() != (None, None)
assert await transport.transport_connect()
transport.close()
transport_server.close()

Expand All @@ -69,7 +69,7 @@ async def test_connect(self, transport, domain_host):
"""Test connect_tcp()."""
transport.setup_tcp(False, domain_host, BASE_PORT + 1)
start = time.time()
assert await transport.transport_connect() == (None, None)
assert not await transport.transport_connect()
delta = time.time() - start
assert delta < transport.comm_params.timeout_connect * 1.2
transport.close()
Expand All @@ -92,7 +92,7 @@ async def test_connected(self, transport, transport_server, domain_host):
server = await transport_server.transport_listen()
assert server
transport.setup_tcp(False, domain_host, BASE_PORT + 3)
assert await transport.transport_connect() != (None, None)
assert await transport.transport_connect()
transport.close()
transport_server.close()
server.close()
Expand All @@ -115,7 +115,7 @@ async def test_connect(self, transport, params, domain_host):
"localhost",
)
start = time.time()
assert await transport.transport_connect() == (None, None)
assert not await transport.transport_connect()
delta = time.time() - start
assert delta < transport.comm_params.timeout_connect * 1.2
transport.close()
Expand Down Expand Up @@ -157,7 +157,7 @@ async def test_connected(self, transport, transport_server, params, domain_host)
assert server

transport.setup_tcp(False, domain_host, BASE_PORT + 7)
assert await transport.transport_connect() != (None, None)
assert await transport.transport_connect()
transport.close()
transport_server.close()
server.close()
Expand Down Expand Up @@ -188,7 +188,7 @@ async def test_connected(self, transport, transport_server, domain_host):
server = await transport_server.transport_listen()
assert server
transport.setup_udp(False, domain_host, BASE_PORT + 11)
assert await transport.transport_connect() != (None, None)
assert await transport.transport_connect()
transport.close()
transport_server.close()
server.close()
Expand All @@ -212,7 +212,7 @@ async def test_connect(self, transport, positive):
2,
)
start = time.time()
assert await transport.transport_connect() == (None, None)
assert not await transport.transport_connect()
delta = time.time() - start
assert delta < transport.comm_params.timeout_connect * 1.2
transport.close()
Expand Down Expand Up @@ -247,7 +247,7 @@ async def test_connected(self, transport, transport_server):
"E",
2,
)
assert await transport.transport_connect() != (None, None)
assert await transport.transport_connect()
transport.close()
transport_server.close()
server.close()
2 changes: 1 addition & 1 deletion test/transport/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ async def test_client_send(self, transport, transport_server):
assert transport_server.transport

transport.setup_tcp(False, "localhost", BASE_PORT + 1)
assert await transport.transport_connect() != (None, None)
assert await transport.transport_connect()
await transport.send(b"ABC")
await asyncio.sleep(2)
assert transport_server.recv_buffer == b"ABC"
Expand Down
2 changes: 1 addition & 1 deletion test/transport/test_reconnect.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,6 @@ async def test_reconnect_call_ok(self, transport, commparams):
transport.connection_lost(RuntimeError("Connection lost"))
await asyncio.sleep(transport.reconnect_delay_current * 1.8)
assert mocker.call_count == 1
assert transport.reconnect_delay_current == commparams.reconnect_delay * 2
assert transport.reconnect_delay_current == commparams.reconnect_delay
assert not transport.reconnect_task
transport.close()

0 comments on commit ebd6693

Please sign in to comment.