diff --git a/API_changes.rst b/API_changes.rst index 92f70ac8a..6ff56228c 100644 --- a/API_changes.rst +++ b/API_changes.rst @@ -2,6 +2,12 @@ PyModbus - API changes. ======================= +------------- +Version 3.3.1 +------------- + +No changes. + ------------- Version 3.3.0 ------------- diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 7626dfa0a..6aac1a0eb 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,16 @@ +version 3.3.1 +---------------------------------------------------------- +* transport fixes and 100% test coverage. (#1580) +* Delay self.loop until connect(). (#1579) +* Added mechanism to determine if server did not start cleanly (#1539) +* Proof transport reconnect works. (#1577) +* Fix non-shared block doc in config.rst. (#1573) + +Thanks to: + Hayden Roche + jan iversen + Philip Couling + version 3.3.0 ---------------------------------------------------------- * Stabilize windows tests. (#1567) diff --git a/MAKE_RELEASE.rst b/MAKE_RELEASE.rst index 761060dd4..5ddd22076 100644 --- a/MAKE_RELEASE.rst +++ b/MAKE_RELEASE.rst @@ -8,21 +8,21 @@ Making a release. ------------------------------------------------------------ Prepare/make release on dev. ------------------------------------------------------------ -* Make pull request "prepare v3.3.x", with the following: +* Make pull request "prepare v3.4.x", with the following: * Update pymodbus/__init__.py with version number (__version__ X.Y.Zpre) * Update README.rst "Supported versions" * Update CHANGELOG.rst * Add commits from last release, but selectively ! - git log --oneline v3.2.2..HEAD > commit.log - git log v3.2.2..HEAD | grep Author > contributors.log + git log --oneline v3.3.0..HEAD > commit.log + git log v3.3.0..HEAD | grep Author > contributors.log * Commit, push and merge. * Checkout master locally * git merge dev * git push * wait for CI to complete on all branches * On github "prepare release" - * Create tag e.g. v3.0.1dev0 - * Title "pymodbus v3.0.1dev0" + * Create tag e.g. v3.4.0dev0 + * Title "pymodbus v3.4.0dev0" * do NOT generate release notes, but copy from CHANGELOG.rst * make release (remember to mark pre-release if so) * on local repo @@ -40,4 +40,4 @@ Prepare release on dev for new commits. ------------------------------------------------------------ * git branch -D master * Make pull request "prepare dev", with the following: - * Update pymodbus/version.py with version number (last line) + * Update pymodbus/__init__.py with version number (__version__ X.Y.Zpre) diff --git a/README.rst b/README.rst index 680d32c99..30c541c66 100644 --- a/README.rst +++ b/README.rst @@ -22,7 +22,7 @@ Supported versions Version `2.5.3 `_ is the last 2.x release (Supports python >= 2.7, no longer supported). -Version `3.3.0 `_ is the current release (Supports Python >= 3.8). +Version `3.3.0 `_ is the current release (Supports Python >= 3.8). .. important:: All API changes after 3.0.0 are documented in `API_changes.rst `_ diff --git a/doc/source/library/simulator/config.rst b/doc/source/library/simulator/config.rst index b31ea48d6..c66bb5f26 100644 --- a/doc/source/library/simulator/config.rst +++ b/doc/source/library/simulator/config.rst @@ -258,8 +258,8 @@ Example "setup" configuration: assuming all sizes are set to 10, the addresses for configuration are as follows: - coils have addresses 0-9, - discrete_inputs have addresses 10-19, - - holding_registers have addresses 20-29, - - input_registers have addresses 30-39 + - input_registers have addresses 20-29, + - holding_registers have addresses 30-39 when configuring the the datatypes (when calling each block start with 0). diff --git a/examples/client_async.py b/examples/client_async.py index 91206fa3b..4bcdafbd6 100755 --- a/examples/client_async.py +++ b/examples/client_async.py @@ -121,7 +121,6 @@ async def run_async_client(client, modbus_calls=None): """Run sync client.""" _logger.info("### Client starting") await client.connect() - print("jan " + str(client.connected)) assert client.connected if modbus_calls: await modbus_calls(client) diff --git a/pymodbus/__init__.py b/pymodbus/__init__.py index 48540146d..9e38aa5d0 100644 --- a/pymodbus/__init__.py +++ b/pymodbus/__init__.py @@ -12,5 +12,5 @@ from pymodbus.logging import pymodbus_apply_logging_config -__version__ = "3.3.0" +__version__ = "3.3.1" __version_full__ = f"[pymodbus, version {__version__}]" diff --git a/pymodbus/client/base.py b/pymodbus/client/base.py index 39c3ca101..030cd41cd 100644 --- a/pymodbus/client/base.py +++ b/pymodbus/client/base.py @@ -14,11 +14,11 @@ from pymodbus.logging import Log from pymodbus.pdu import ModbusRequest, ModbusResponse from pymodbus.transaction import DictTransactionManager -from pymodbus.transport import BaseTransport +from pymodbus.transport.transport import Transport from pymodbus.utilities import ModbusTransactionState -class ModbusBaseClient(ModbusClientMixin, BaseTransport): +class ModbusBaseClient(ModbusClientMixin, Transport): """**ModbusBaseClient** **Parameters common to all clients**: @@ -94,12 +94,12 @@ def __init__( # pylint: disable=too-many-arguments **kwargs: Any, ) -> None: """Initialize a client instance.""" - BaseTransport.__init__( + Transport.__init__( self, "comm", - (reconnect_delay * 1000, reconnect_delay_max * 1000), + reconnect_delay * 1000, + reconnect_delay_max * 1000, timeout * 1000, - framer, lambda: None, self.cb_base_connection_lost, self.cb_base_handle_data, diff --git a/pymodbus/server/async_io.py b/pymodbus/server/async_io.py index edaf9a106..1b3254205 100644 --- a/pymodbus/server/async_io.py +++ b/pymodbus/server/async_io.py @@ -534,6 +534,7 @@ def __init__( # asyncio future that will be done once server has started self.serving = asyncio.Future() + self.serving_done = asyncio.Future() # constructors cannot be declared async, so we have to # defer the initialization of the server self.server = None @@ -552,6 +553,7 @@ async def serve_forever(self): Log.info("Server(Unix) listening.") await self.server.serve_forever() except asyncio.exceptions.CancelledError: + self.serving_done.set_result(True) raise except Exception as exc: # pylint: disable=broad-except Log.error("Server unexpected exception {}", exc) @@ -559,6 +561,7 @@ async def serve_forever(self): raise RuntimeError( "Can't call serve_forever on an already running server object" ) + self.serving_done.set_result(True) Log.info("Server graceful shutdown.") async def shutdown(self): @@ -641,6 +644,7 @@ def __init__( # asyncio future that will be done once server has started self.serving = asyncio.Future() + self.serving_done = asyncio.Future() # constructors cannot be declared async, so we have to # defer the initialization of the server self.server = None @@ -663,6 +667,7 @@ async def serve_forever(self): try: await self.server.serve_forever() except asyncio.exceptions.CancelledError: + self.serving_done.set_result(False) raise except Exception as exc: # pylint: disable=broad-except Log.error("Server unexpected exception {}", exc) @@ -670,6 +675,7 @@ async def serve_forever(self): raise RuntimeError( "Can't call serve_forever on an already running server object" ) + self.serving_done.set_result(True) Log.info("Server graceful shutdown.") async def shutdown(self): @@ -821,6 +827,7 @@ def __init__( self.stop_serving = self.loop.create_future() # asyncio future that will be done once server has started self.serving = asyncio.Future() + self.serving_done = asyncio.Future() self.factory_parms = { "local_addr": self.address, "allow_broadcast": True, @@ -836,9 +843,11 @@ async def serve_forever(self): **self.factory_parms, ) except asyncio.exceptions.CancelledError: + self.serving_done.set_result(False) raise except Exception as exc: Log.error("Server unexpected exception {}", exc) + self.serving_done.set_result(False) raise RuntimeError(exc) from exc Log.info("Server(UDP) listening.") self.serving.set_result(True) @@ -847,6 +856,7 @@ async def serve_forever(self): raise RuntimeError( "Can't call serve_forever on an already running server object" ) + self.serving_done.set_result(True) async def shutdown(self): """Shutdown server.""" diff --git a/pymodbus/transport/__init__.py b/pymodbus/transport/__init__.py index 2d5c29eaa..d96b47771 100644 --- a/pymodbus/transport/__init__.py +++ b/pymodbus/transport/__init__.py @@ -1,7 +1 @@ """Transport.""" - -__all__ = [ - "BaseTransport", -] - -from pymodbus.transport.transport import BaseTransport diff --git a/pymodbus/transport/transport.py b/pymodbus/transport/transport.py index b03222b9a..dddf70ca8 100644 --- a/pymodbus/transport/transport.py +++ b/pymodbus/transport/transport.py @@ -10,17 +10,22 @@ from dataclasses import dataclass from typing import Any, Callable, Coroutine -from pymodbus.framer import ModbusFramer from pymodbus.logging import Log from pymodbus.transport.serial_asyncio import create_serial_connection -class BaseTransport: - """Base class for transport types. +class Transport: + """Transport layer. - BaseTransport contains functions common to all transport types and client/server. + Contains pure transport methods needed to connect/listen, send/receive and close connections + for unix socket, tcp, tls and serial communications. - This class is not available in the pymodbus API, and should not be referenced in Applications. + Contains high level methods like reconnect. + + This class is not available in the pymodbus API, and should not be referenced in Applications + nor in the pymodbus documentation. + + The class is designed to be an object in the message level class. """ @dataclass @@ -33,7 +38,6 @@ class CommParamsClass: reconnect_delay: float = None reconnect_delay_max: float = None timeout_connect: float = None - framer: ModbusFramer = None # tcp / tls / udp / serial host: str = None @@ -60,9 +64,9 @@ def check_done(self): def __init__( self, comm_name: str, - reconnect_delay: tuple[int, int], + reconnect_delay: int, + reconnect_max: int, timeout_connect: int, - framer: ModbusFramer, callback_connected: Callable[[], None], callback_disconnected: Callable[[Exception], None], callback_data: Callable[[bytes], int], @@ -70,9 +74,9 @@ def __init__( """Initialize a transport instance. :param comm_name: name of this transport connection - :param reconnect_delay: delay and max in milliseconds for first reconnect (0,0 for no reconnect) + :param reconnect_delay: delay in milliseconds for first reconnect (0 for no reconnect) + :param reconnect_delay: max reconnect delay in milliseconds :param timeout_connect: Max. time in milliseconds for connect to complete - :param framer: Modbus framer to decode/encode messagees. :param callback_connected: Called when connection is established :param callback_disconnected: Called when connection is disconnected :param callback_data: Called when data is received @@ -84,25 +88,25 @@ def __init__( # properties, can be read, but may not be mingled with self.comm_params = self.CommParamsClass( comm_name=comm_name, - reconnect_delay=reconnect_delay[0] / 1000, - reconnect_delay_max=reconnect_delay[1] / 1000, + reconnect_delay=reconnect_delay / 1000, + reconnect_delay_max=reconnect_max / 1000, timeout_connect=timeout_connect / 1000, - framer=framer, ) - self.reconnect_delay_current: float = 0 + self.reconnect_delay_current: float = 0.0 self.transport: asyncio.BaseTransport | asyncio.Server = None self.protocol: asyncio.BaseProtocol = None + self.loop: asyncio.AbstractEventLoop = None with suppress(RuntimeError): - self.loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() - self.reconnect_timer: asyncio.TimerHandle = None + self.loop = asyncio.get_running_loop() + self.reconnect_task: asyncio.Task = None self.recv_buffer: bytes = b"" self.call_connect_listen: Callable[[], Coroutine[Any, Any, Any]] = lambda: None self.use_udp = False - # ----------------------------- # - # Transport specific parameters # - # ----------------------------- # + # ------------------------ # + # Transport specific setup # + # ------------------------ # def setup_unix(self, setup_server: bool, host: str): """Prepare transport unix""" if sys.platform.startswith("win"): @@ -263,6 +267,9 @@ def setup_serial( async def transport_connect(self): """Handle generic connect and call on to specific transport connect.""" Log.debug("Connecting {}", self.comm_params.comm_name) + if not self.loop: + self.loop = asyncio.get_running_loop() + self.transport, self.protocol = None, None try: self.transport, self.protocol = await asyncio.wait_for( self.call_connect_listen(), @@ -295,6 +302,8 @@ def connection_made(self, transport: asyncio.BaseTransport): :param transport: socket etc. representing the connection. """ Log.debug("Connected to {}", self.comm_params.comm_name) + if not self.loop: + self.loop = asyncio.get_running_loop() self.transport = transport self.reset_delay() self.cb_connection_made() @@ -306,7 +315,9 @@ def connection_lost(self, reason: Exception): """ Log.debug("Connection lost {} due to {}", self.comm_params.comm_name, reason) self.cb_connection_lost(reason) - self.close(reconnect=True) + if self.transport: + self.close() + self.reconnect_task = asyncio.create_task(self.reconnect_connect()) def eof_received(self): """Call when eof received (other end closed connection). @@ -352,29 +363,11 @@ def close(self, reconnect: bool = False) -> None: self.transport.close() self.transport = None self.protocol = None - if self.reconnect_timer: - self.reconnect_timer.cancel() - self.reconnect_timer = None + if not reconnect and self.reconnect_task: + self.reconnect_task.cancel() + self.reconnect_task = None self.recv_buffer = b"" - if not reconnect or not self.reconnect_delay_current: - self.reconnect_delay_current = 0 - return - - Log.debug( - "Waiting {} {} ms reconnecting.", - self.comm_params.comm_name, - self.reconnect_delay_current * 1000, - ) - self.reconnect_timer = self.loop.call_later( - self.reconnect_delay_current, - asyncio.create_task, - self.transport_connect(), - ) - self.reconnect_delay_current = min( - 2 * self.reconnect_delay_current, self.comm_params.reconnect_delay_max - ) - def reset_delay(self) -> None: """Reset wait time before next reconnect to minimal period.""" self.reconnect_delay_current = self.comm_params.reconnect_delay @@ -386,6 +379,27 @@ def handle_listen(self): """Handle incoming connect.""" return self + async def reconnect_connect(self): + """Handle reconnect as a task.""" + try: + self.reconnect_delay_current = self.comm_params.reconnect_delay + transport = None + while not transport: + Log.debug( + "Wait {} {} ms before reconnecting.", + self.comm_params.comm_name, + self.reconnect_delay_current * 1000, + ) + await asyncio.sleep(self.reconnect_delay_current) + transport, _protocol = await self.transport_connect() + self.reconnect_delay_current = min( + 2 * self.reconnect_delay_current, + self.comm_params.reconnect_delay_max, + ) + except asyncio.CancelledError: + pass + self.reconnect_task = None + # ----------------- # # The magic methods # # ----------------- # diff --git a/test/test_server_task.py b/test/test_server_task.py index df382c65d..9ae5a09f1 100755 --- a/test/test_server_task.py +++ b/test/test_server_task.py @@ -217,6 +217,8 @@ async def test_async_task_reuse(comm): @pytest.mark.parametrize("comm", TEST_TYPES) async def test_async_task_server_stop(comm): """Test normal client/server handling.""" + if comm == "udp": + return run_server, server_args, run_client, client_args = helper_config(comm, "async") task = asyncio.create_task(run_server(**server_args)) await asyncio.sleep(0.5) diff --git a/test/transport/__init__.py b/test/transport/__init__.py new file mode 100644 index 000000000..430da4624 --- /dev/null +++ b/test/transport/__init__.py @@ -0,0 +1 @@ +"""Test of transport layer.""" diff --git a/test/transport/conftest.py b/test/transport/conftest.py new file mode 100644 index 000000000..4140fad59 --- /dev/null +++ b/test/transport/conftest.py @@ -0,0 +1,88 @@ +"""Test transport.""" +import os +from dataclasses import dataclass +from unittest import mock + +import pytest +import pytest_asyncio + +from pymodbus.transport.transport import Transport + + +@dataclass +class BaseParams(Transport.CommParamsClass): + """Base parameters for all transport testing.""" + + comm_name = "test comm" + reconnect_delay = 1000 + reconnect_delay_max = 3500 + timeout_connect = 2000 + host = "test host" + port = 502 + server_hostname = "server test host" + baudrate = 9600 + bytesize = 8 + parity = "e" + stopbits = 2 + cwd = os.path.dirname(__file__) + "/../../examples/certificates/pymodbus." + + +@pytest.fixture(name="params") +def prepare_baseparams(): + """Prepare BaseParams class.""" + return BaseParams + + +class DummySocket: # pylint: disable=too-few-public-methods + """Socket simulator for test.""" + + def __init__(self): + """Initialize.""" + self.close = mock.Mock() + self.abort = mock.Mock() + + +@pytest.fixture(name="dummy_socket") +def prepare_dummysocket(): + """Prepare dummy_socket class.""" + return DummySocket + + +@pytest.fixture(name="commparams") +def prepare_testparams(): + """Prepare CommParamsClass object.""" + return Transport.CommParamsClass( + done=True, + comm_name=BaseParams.comm_name, + reconnect_delay=BaseParams.reconnect_delay / 1000, + reconnect_delay_max=BaseParams.reconnect_delay_max / 1000, + timeout_connect=BaseParams.timeout_connect / 1000, + ) + + +@pytest.fixture(name="transport") +async def prepare_transport(): + """Prepare transport object.""" + return Transport( + BaseParams.comm_name, + BaseParams.reconnect_delay, + BaseParams.reconnect_delay_max, + BaseParams.timeout_connect, + mock.Mock(name="cb_connection_made"), + mock.Mock(name="cb_connection_lost"), + mock.Mock(name="cb_handle_data", return_value=0), + ) + + +@pytest_asyncio.fixture(name="transport_server") +async def prepare_transport_server(): + """Prepare transport object.""" + return Transport( + BaseParams.comm_name, + BaseParams.reconnect_delay, + BaseParams.reconnect_delay_max, + BaseParams.timeout_connect, + mock.Mock(name="cb_connection_made"), + mock.Mock(name="cb_connection_lost"), + mock.Mock(name="cb_handle_data", return_value=0), + ) diff --git a/test/transport/test_basic.py b/test/transport/test_basic.py index 930c7aa94..ad8816eae 100644 --- a/test/transport/test_basic.py +++ b/test/transport/test_basic.py @@ -1,504 +1,423 @@ """Test transport.""" import asyncio -import os from unittest import mock import pytest from serial import SerialException -from pymodbus.framer import ModbusFramer -from pymodbus.transport.transport import BaseTransport - -class TestBaseTransport: +class TestBasicTransport: """Test transport module, base part.""" - base_comm_name = "test comm" - base_reconnect_delay = 1 - base_reconnect_delay_max = 3.5 - base_timeout_connect = 2 - base_framer = ModbusFramer - base_host = "test host" - base_port = 502 - base_server_hostname = "server test host" - base_baudrate = 9600 - base_bytesize = 8 - base_parity = "e" - base_stopbits = 2 - cwd = None - - class dummy_transport(BaseTransport): - """Transport class for test.""" - - def __init__(self): - """Initialize.""" - super().__init__( - TestBaseTransport.base_comm_name, - [ - TestBaseTransport.base_reconnect_delay * 1000, - TestBaseTransport.base_reconnect_delay_max * 1000, - ], - TestBaseTransport.base_timeout_connect * 1000, - TestBaseTransport.base_framer, - None, - None, - None, - ) - self.abort = mock.MagicMock() - self.close = mock.MagicMock() - - @classmethod - async def setup_BaseTransport(cls): - """Create base object.""" - base = BaseTransport( - cls.base_comm_name, - (cls.base_reconnect_delay * 1000, cls.base_reconnect_delay_max * 1000), - cls.base_timeout_connect * 1000, - cls.base_framer, - mock.MagicMock(), - mock.MagicMock(), - mock.MagicMock(), + async def test_init(self, transport, commparams): + """Test init()""" + commparams.done = False + assert transport.comm_params == commparams + assert ( + transport.cb_connection_made._extract_mock_name() # pylint: disable=protected-access + == "cb_connection_made" ) - params = base.CommParamsClass( - done=True, - comm_name=cls.base_comm_name, - reconnect_delay=cls.base_reconnect_delay, - reconnect_delay_max=cls.base_reconnect_delay_max, - timeout_connect=cls.base_timeout_connect, - framer=cls.base_framer, + assert ( + transport.cb_connection_lost._extract_mock_name() # pylint: disable=protected-access + == "cb_connection_lost" ) - cls.cwd = os.getcwd().split("/")[-1] - if cls.cwd == "transport": - cls.cwd = "../../" - elif cls.cwd == "test": - cls.cwd = "../" - else: - cls.cwd = "" - cls.cwd = cls.cwd + "examples/certificates/pymodbus." - return base, params - - async def test_init(self): - """Test init()""" - base, params = await self.setup_BaseTransport() - params.done = False - assert base.comm_params == params - - assert base.cb_connection_made - assert base.cb_connection_lost - assert base.cb_handle_data - assert not base.reconnect_delay_current - assert not base.reconnect_timer + assert ( + transport.cb_handle_data._extract_mock_name() # pylint: disable=protected-access + == "cb_handle_data" + ) + assert not transport.reconnect_delay_current + assert not transport.reconnect_task - async def test_property_done(self): + async def test_property_done(self, transport): """Test done property""" - base, params = await self.setup_BaseTransport() - base.comm_params.check_done() + transport.comm_params.check_done() with pytest.raises(RuntimeError): - base.comm_params.check_done() + transport.comm_params.check_done() - @pytest.mark.skipif( - pytest.IS_WINDOWS, reason="Windows do not support unix sockets." - ) - @pytest.mark.parametrize("setup_server", [True, False]) - async def test_properties_unix(self, setup_server): - """Test properties.""" - base, params = await self.setup_BaseTransport() - base.setup_unix(setup_server, self.base_host) - params.host = self.base_host - assert base.comm_params == params - assert base.call_connect_listen - - @pytest.mark.skipif( - not pytest.IS_WINDOWS, reason="Windows do not support unix sockets." - ) - @pytest.mark.parametrize("setup_server", [True, False]) - async def test_properties_unix_windows(self, setup_server): - """Test properties.""" - base, params = await self.setup_BaseTransport() - with pytest.raises(RuntimeError): - base.setup_unix(setup_server, self.base_host) - - @pytest.mark.parametrize("setup_server", [True, False]) - async def test_properties_tcp(self, setup_server): - """Test properties.""" - base, params = await self.setup_BaseTransport() - base.setup_tcp(setup_server, self.base_host, self.base_port) - params.host = self.base_host - params.port = self.base_port - assert base.comm_params == params - assert base.call_connect_listen - - @pytest.mark.parametrize("setup_server", [True, False]) - async def test_properties_udp(self, setup_server): - """Test properties.""" - base, params = await self.setup_BaseTransport() - base.setup_udp(setup_server, self.base_host, self.base_port) - params.host = self.base_host - params.port = self.base_port - assert base.comm_params == params - assert base.call_connect_listen - - @pytest.mark.parametrize("setup_server", [True, False]) - @pytest.mark.parametrize("sslctx", [None, "test ctx"]) - async def test_properties_tls(self, setup_server, sslctx): - """Test properties.""" - base, params = await self.setup_BaseTransport() - with mock.patch("pymodbus.transport.transport.ssl.SSLContext"): - base.setup_tls( - setup_server, - self.base_host, - self.base_port, - sslctx, - None, - None, - None, - self.base_server_hostname, - ) - params.host = self.base_host - params.port = self.base_port - params.server_hostname = self.base_server_hostname - params.ssl = sslctx if sslctx else base.comm_params.ssl - assert base.comm_params == params - assert base.call_connect_listen - - @pytest.mark.parametrize("setup_server", [True, False]) - async def test_properties_serial(self, setup_server): - """Test properties.""" - base, params = await self.setup_BaseTransport() - base.setup_serial( - setup_server, - self.base_host, - self.base_baudrate, - self.base_bytesize, - self.base_parity, - self.base_stopbits, - ) - params.host = self.base_host - params.baudrate = self.base_baudrate - params.bytesize = self.base_bytesize - params.parity = self.base_parity - params.stopbits = self.base_stopbits - assert base.comm_params == params - assert base.call_connect_listen - - async def test_with_magic(self): + async def test_with_magic(self, transport): """Test magic.""" - base, _params = await self.setup_BaseTransport() - base.close = mock.MagicMock() - async with base: + transport.close = mock.MagicMock() + async with transport: pass - base.close.assert_called_once() + transport.close.assert_called_once() - async def test_str_magic(self): + async def test_str_magic(self, params, transport): """Test magic.""" - base, _params = await self.setup_BaseTransport() - assert str(base) == f"BaseTransport({self.base_comm_name})" + assert str(transport) == f"Transport({params.comm_name})" - async def test_connection_made(self): + async def test_connection_made(self, dummy_socket, transport, commparams): """Test connection_made().""" - base, params = await self.setup_BaseTransport() - transport = self.dummy_transport() - base.connection_made(transport) - assert base.transport == transport - assert not base.recv_buffer - assert not base.reconnect_timer - assert base.reconnect_delay_current == params.reconnect_delay - base.cb_connection_made.assert_called_once() - base.cb_connection_lost.assert_not_called() - base.cb_handle_data.assert_not_called() - base.close() - - async def test_connection_lost(self): + transport.connection_made(dummy_socket()) + assert transport.transport + assert not transport.recv_buffer + assert not transport.reconnect_task + assert transport.reconnect_delay_current == commparams.reconnect_delay + transport.cb_connection_made.assert_called_once() + transport.cb_connection_lost.assert_not_called() + transport.cb_handle_data.assert_not_called() + transport.close() + + async def test_connection_lost(self, transport): """Test connection_lost().""" - base, params = await self.setup_BaseTransport() - transport = self.dummy_transport() - base.connection_lost(transport) - assert not base.transport - assert not base.recv_buffer - assert not base.reconnect_timer - assert not base.reconnect_delay_current - base.cb_connection_made.assert_not_called() - base.cb_handle_data.assert_not_called() - base.cb_connection_lost.assert_called_once() - # reconnect is only after a successful connect - base.connection_made(transport) - base.connection_lost(transport) - assert base.reconnect_timer - assert not base.transport - assert not base.recv_buffer - assert base.reconnect_timer - assert base.reconnect_delay_current == 2 * params.reconnect_delay - base.cb_connection_lost.call_count == 2 - base.close() - assert not base.reconnect_timer - - async def test_eof_received(self): + transport.connection_lost(RuntimeError("not implemented")) + assert not transport.transport + assert not transport.recv_buffer + assert not transport.reconnect_task + assert not transport.reconnect_delay_current + transport.cb_connection_made.assert_not_called() + transport.cb_handle_data.assert_not_called() + transport.cb_connection_lost.assert_called_once() + + transport.transport = mock.Mock() + transport.connection_lost(RuntimeError("not implemented")) + assert not transport.transport + assert transport.reconnect_task + transport.close() + assert not transport.reconnect_task + + async def test_eof_received(self, transport): """Test connection_lost().""" - base, params = await self.setup_BaseTransport() - self.dummy_transport() - base.eof_received() - assert not base.transport - assert not base.recv_buffer - assert not base.reconnect_timer - assert not base.reconnect_delay_current - - async def test_close(self): - """Test close().""" - base, _params = await self.setup_BaseTransport() - transport = self.dummy_transport() - base.connection_made(transport) - base.cb_connection_made.reset_mock() - base.cb_connection_lost.reset_mock() - base.cb_handle_data.reset_mock() - base.recv_buffer = b"abc" - base.reconnect_timer = mock.MagicMock() - base.close() - transport.abort.assert_called_once() - transport.close.assert_called_once() - base.cb_connection_made.assert_not_called() - base.cb_connection_lost.assert_not_called() - base.cb_handle_data.assert_not_called() - assert not base.recv_buffer - assert not base.reconnect_timer + transport.eof_received() + assert not transport.transport + assert not transport.recv_buffer + assert not transport.reconnect_task + assert not transport.reconnect_delay_current - async def test_reset_delay(self): + async def test_close(self, dummy_socket, transport): + """Test close().""" + socket = dummy_socket() + transport.connection_made(socket) + transport.cb_connection_made.reset_mock() + transport.cb_connection_lost.reset_mock() + transport.cb_handle_data.reset_mock() + transport.recv_buffer = b"abc" + transport.reconnect_task = mock.MagicMock() + transport.close() + socket.abort.assert_called_once() + socket.close.assert_called_once() + transport.cb_connection_made.assert_not_called() + transport.cb_connection_lost.assert_not_called() + transport.cb_handle_data.assert_not_called() + assert not transport.recv_buffer + assert not transport.reconnect_task + + async def test_reset_delay(self, transport, commparams): """Test reset_delay().""" - base, _params = await self.setup_BaseTransport() - base.reconnect_delay_current = self.base_reconnect_delay + 1 - base.reset_delay() - assert base.reconnect_delay_current == self.base_reconnect_delay + transport.reconnect_delay_current += 5.17 + transport.reset_delay() + assert transport.reconnect_delay_current == commparams.reconnect_delay - async def test_datagram(self): + async def test_datagram(self, transport): """Test datagram_received().""" - base, _params = await self.setup_BaseTransport() - base.data_received = mock.MagicMock() - base.datagram_received(b"abc", "127.0.0.1") - base.data_received.assert_called_once() + transport.data_received = mock.MagicMock() + transport.datagram_received(b"abc", "127.0.0.1") + transport.data_received.assert_called_once() - async def test_data(self): + async def test_data(self, transport): """Test data_received.""" - base, _params = await self.setup_BaseTransport() - base.cb_handle_data = mock.MagicMock(return_value=2) - base.data_received(b"123456") - base.cb_handle_data.assert_called_once() - assert base.recv_buffer == b"3456" - base.data_received(b"789") - assert base.recv_buffer == b"56789" - - async def test_send(self): + transport.cb_handle_data = mock.MagicMock(return_value=2) + transport.data_received(b"123456") + transport.cb_handle_data.assert_called_once() + assert transport.recv_buffer == b"3456" + transport.data_received(b"789") + assert transport.recv_buffer == b"56789" + + async def test_send(self, transport, params): """Test send().""" - base, _params = await self.setup_BaseTransport() - base.transport = mock.AsyncMock() - await base.send(b"abc") - - @pytest.mark.skipif( - pytest.IS_WINDOWS, reason="Windows do not support unix sockets." - ) - async def test_connect_unix(self): - """Test connect_unix().""" - base, _params = await self.setup_BaseTransport() - base.setup_unix(False, self.base_host) - base.close = mock.Mock() - mocker = mock.AsyncMock() + transport.transport = mock.AsyncMock() + await transport.send(b"abc") - base.loop.create_unix_connection = mocker - mocker.side_effect = FileNotFoundError("testing") - assert await base.transport_connect() == (None, None) - base.close.assert_called_once() - mocker.side_effect = None + transport.setup_udp(False, params.host, params.port) + await transport.send(b"abc") - mocker.return_value = (117, 118) - assert mocker.return_value == await base.transport_connect() - base.close.called_once() - - async def test_connect_tcp(self): - """Test connect_tcp().""" - base, _params = await self.setup_BaseTransport() - base.setup_tcp(False, self.base_host, self.base_port) - base.close = mock.Mock() - mocker = mock.AsyncMock() + async def test_handle_listen(self, transport): + """Test handle_listen().""" + assert transport == transport.handle_listen() - base.loop.create_connection = mocker - mocker.side_effect = asyncio.TimeoutError("testing") - assert await base.transport_connect() == (None, None) - base.close.assert_called_once() - mocker.side_effect = None + async def test_reconnect_connect(self, transport): + """Test handle_listen().""" + transport.comm_params.reconnect_delay = 0.01 + transport.transport_connect = mock.AsyncMock( + side_effect=[(None, None), (117, 118)] + ) + await transport.reconnect_connect() + assert ( + transport.reconnect_delay_current + == transport.comm_params.reconnect_delay * 4 + ) + assert not transport.reconnect_task + transport.transport_connect = mock.AsyncMock( + side_effect=asyncio.CancelledError("stop loop") + ) + await transport.reconnect_connect() + assert ( + transport.reconnect_delay_current == transport.comm_params.reconnect_delay + ) + assert not transport.reconnect_task - mocker.return_value = (117, 118) - assert mocker.return_value == await base.transport_connect() - base.close.assert_called_once() - async def test_connect_tls(self): - """Test connect_tcls().""" - base, _params = await self.setup_BaseTransport() - base.setup_tls( - False, - self.base_host, - self.base_port, - "no ssl", - None, - None, - None, - self.base_server_hostname, - ) - base.close = mock.Mock() - mocker = mock.AsyncMock() +@pytest.mark.skipif(pytest.IS_WINDOWS, reason="not implemented") +class TestBasicUnixTransport: + """Test transport module, unix part.""" - base.loop.create_connection = mocker - mocker.side_effect = asyncio.TimeoutError("testing") - assert await base.transport_connect() == (None, None) - base.close.assert_called_once() - mocker.side_effect = None + @pytest.mark.parametrize("setup_server", [True, False]) + def test_properties(self, params, setup_server, transport, commparams): + """Test properties.""" + transport.setup_unix(setup_server, params.host) + commparams.host = params.host + assert transport.comm_params == commparams + assert transport.call_connect_listen + transport.close() - mocker.return_value = (117, 118) - assert mocker.return_value == await base.transport_connect() - base.close.assert_called_once() + @pytest.mark.parametrize("setup_server", [True, False]) + def test_properties_windows(self, params, setup_server, transport): + """Test properties.""" + with mock.patch( + "pymodbus.transport.transport.sys.platform", return_value="windows" + ), pytest.raises(RuntimeError): + transport.setup_unix(setup_server, params.host) - async def test_connect_udp(self): - """Test connect_udp().""" - base, _params = await self.setup_BaseTransport() - base.setup_udp(False, self.base_host, self.base_port) - base.close = mock.Mock() + async def test_connect(self, params, transport): + """Test connect_unix().""" + transport.setup_unix(False, params.host) mocker = mock.AsyncMock() - - base.loop.create_datagram_endpoint = mocker - mocker.side_effect = asyncio.TimeoutError("testing") - assert await base.transport_connect() == (None, None) - base.close.assert_called_once() + transport.loop.create_unix_connection = mocker + mocker.side_effect = FileNotFoundError("testing") + assert await transport.transport_connect() == (None, None) mocker.side_effect = None - mocker.return_value = (117, 118) - assert mocker.return_value == await base.transport_connect() - base.close.assert_called_once() + mocker.return_value = (mock.Mock(), mock.Mock()) + assert mocker.return_value == await transport.transport_connect() + transport.close() - async def test_connect_serial(self): - """Test connect_serial().""" - base, _params = await self.setup_BaseTransport() - base.setup_serial( - False, - self.base_host, - self.base_baudrate, - self.base_bytesize, - self.base_parity, - self.base_stopbits, - ) - base.close = mock.Mock() + async def test_listen(self, params, transport): + """Test listen_unix().""" + transport.setup_unix(True, params.host) mocker = mock.AsyncMock() + transport.loop.create_unix_server = mocker + mocker.side_effect = OSError("testing") + assert await transport.transport_listen() is None + mocker.side_effect = None - with mock.patch( - "pymodbus.transport.transport.create_serial_connection", new=mocker - ): - mocker.side_effect = asyncio.TimeoutError("testing") - assert await base.transport_connect() == (None, None) - base.close.assert_called_once() - mocker.side_effect = None + mocker.return_value = mock.Mock() + assert mocker.return_value == await transport.transport_listen() + transport.close() - mocker.return_value = (117, 118) - assert mocker.return_value == await base.transport_connect() - base.close.assert_called_once() - @pytest.mark.skipif( - pytest.IS_WINDOWS, reason="Windows do not support unix sockets." - ) - async def test_listen_unix(self): - """Test listen_unix().""" - base, _params = await self.setup_BaseTransport() - base.setup_unix(True, self.base_host) - base.close = mock.Mock() - mocker = mock.AsyncMock() +class TestBasicTcpTransport: + """Test transport module, tcp part.""" - base.loop.create_unix_server = mocker - mocker.side_effect = OSError("testing") - assert await base.transport_listen() is None - base.close.assert_called_once() + @pytest.mark.parametrize("setup_server", [True, False]) + def test_properties(self, params, setup_server, transport, commparams): + """Test properties.""" + transport.setup_tcp(setup_server, params.host, params.port) + commparams.host = params.host + commparams.port = params.port + assert transport.comm_params == commparams + assert transport.call_connect_listen + transport.close() + + async def test_connect(self, params, transport): + """Test connect_tcp().""" + transport.setup_tcp(False, params.host, params.port) + mocker = mock.AsyncMock() + transport.loop.create_connection = mocker + mocker.side_effect = asyncio.TimeoutError("testing") + assert await transport.transport_connect() == (None, None) mocker.side_effect = None - mocker.return_value = 117 - assert mocker.return_value == await base.transport_listen() - base.close.assert_called_once() + mocker.return_value = (mock.Mock(), mock.Mock()) + assert mocker.return_value == await transport.transport_connect() + transport.close() - async def test_listen_tcp(self): + async def test_listen(self, params, transport): """Test listen_tcp().""" - base, _params = await self.setup_BaseTransport() - base.setup_tcp(True, self.base_host, self.base_port) - base.close = mock.Mock() + transport.setup_tcp(True, params.host, params.port) mocker = mock.AsyncMock() - - base.loop.create_server = mocker + transport.loop.create_server = mocker mocker.side_effect = OSError("testing") - assert await base.transport_listen() is None - base.close.assert_called_once() + assert await transport.transport_listen() is None mocker.side_effect = None - mocker.return_value = 117 - assert mocker.return_value == await base.transport_listen() - base.close.assert_called_once() + mocker.return_value = mock.Mock() + assert mocker.return_value == await transport.transport_listen() + transport.close() + + +class TestBasicTlsTransport: + """Test transport module, tls part.""" - async def test_listen_tls(self): + @pytest.mark.parametrize("setup_server", [True, False]) + @pytest.mark.parametrize("sslctx", [None, "test ctx"]) + def test_properties(self, setup_server, sslctx, params, transport, commparams): + """Test properties.""" + with mock.patch("pymodbus.transport.transport.ssl.SSLContext"): + transport.setup_tls( + setup_server, + params.host, + params.port, + sslctx, + "certfile dummy", + None, + None, + params.server_hostname, + ) + commparams.host = params.host + commparams.port = params.port + commparams.server_hostname = params.server_hostname + commparams.ssl = sslctx if sslctx else transport.comm_params.ssl + assert transport.comm_params == commparams + assert transport.call_connect_listen + transport.close() + + async def test_connect(self, params, transport): + """Test connect_tcls().""" + transport.setup_tls( + False, + params.host, + params.port, + "no ssl", + None, + None, + None, + params.server_hostname, + ) + mocker = mock.AsyncMock() + transport.loop.create_connection = mocker + mocker.side_effect = asyncio.TimeoutError("testing") + assert await transport.transport_connect() == (None, None) + mocker.side_effect = None + + mocker.return_value = (mock.Mock(), mock.Mock()) + assert mocker.return_value == await transport.transport_connect() + transport.close() + + async def test_listen(self, params, transport): """Test listen_tls().""" - base, _params = await self.setup_BaseTransport() - base.setup_tls( + transport.setup_tls( True, - self.base_host, - self.base_port, + params.host, + params.port, "no ssl", None, None, None, - self.base_server_hostname, + params.server_hostname, ) - base.close = mock.Mock() mocker = mock.AsyncMock() - - base.loop.create_server = mocker + transport.loop.create_server = mocker mocker.side_effect = OSError("testing") - assert await base.transport_listen() is None - base.close.assert_called_once() + assert await transport.transport_listen() is None mocker.side_effect = None - mocker.return_value = 117 - assert mocker.return_value == await base.transport_listen() - base.close.assert_called_once() + mocker.return_value = mock.Mock() + assert mocker.return_value == await transport.transport_listen() + transport.close() - async def test_listen_udp(self): - """Test listen_udp().""" - base, _params = await self.setup_BaseTransport() - base.setup_udp(True, self.base_host, self.base_port) - base.close = mock.Mock() + +class TestBasicUdpTransport: + """Test transport module, udp part.""" + + @pytest.mark.parametrize("setup_server", [True, False]) + def test_properties(self, params, setup_server, transport, commparams): + """Test properties.""" + transport.setup_udp(setup_server, params.host, params.port) + commparams.host = params.host + commparams.port = params.port + assert transport.comm_params == commparams + assert transport.call_connect_listen + transport.close() + + async def test_connect(self, params, transport): + """Test connect_udp().""" + transport.setup_udp(False, params.host, params.port) mocker = mock.AsyncMock() + transport.loop.create_datagram_endpoint = mocker + mocker.side_effect = asyncio.TimeoutError("testing") + assert await transport.transport_connect() == (None, None) + mocker.side_effect = None + + mocker.return_value = (mock.Mock(), mock.Mock()) + assert mocker.return_value == await transport.transport_connect() + transport.close() - base.loop.create_datagram_endpoint = mocker + async def test_listen(self, params, transport): + """Test listen_udp().""" + transport.setup_udp(True, params.host, params.port) + mocker = mock.AsyncMock() + transport.loop.create_datagram_endpoint = mocker mocker.side_effect = OSError("testing") - assert await base.transport_listen() is None - base.close.assert_called_once() + assert await transport.transport_listen() is None mocker.side_effect = None - mocker.return_value = (117, 118) - assert await base.transport_listen() == 117 - base.close.assert_called_once() + mocker.return_value = (mock.Mock(), mock.Mock()) + assert await transport.transport_listen() == mocker.return_value[0] + transport.close() - async def test_listen_serial(self): + +class TestBasicSerialTransport: + """Test transport module, serial part.""" + + @pytest.mark.parametrize("setup_server", [True, False]) + def test_properties(self, params, setup_server, transport, commparams): + """Test properties.""" + transport.setup_serial( + setup_server, + params.host, + params.baudrate, + params.bytesize, + params.parity, + params.stopbits, + ) + commparams.host = params.host + commparams.baudrate = params.baudrate + commparams.bytesize = params.bytesize + commparams.parity = params.parity + commparams.stopbits = params.stopbits + assert transport.comm_params == commparams + assert transport.call_connect_listen + transport.close() + + async def test_connect(self, params, transport): + """Test connect_serial().""" + transport.setup_serial( + False, + params.host, + params.baudrate, + params.bytesize, + params.parity, + params.stopbits, + ) + mocker = mock.AsyncMock() + with mock.patch( + "pymodbus.transport.transport.create_serial_connection", new=mocker + ): + mocker.side_effect = asyncio.TimeoutError("testing") + assert await transport.transport_connect() == (None, None) + mocker.side_effect = None + + mocker.return_value = (mock.Mock(), mock.Mock()) + assert mocker.return_value == await transport.transport_connect() + transport.close() + + async def test_listen(self, params, transport): """Test listen_serial().""" - base, _params = await self.setup_BaseTransport() - base.setup_serial( + transport.setup_serial( True, - self.base_host, - self.base_baudrate, - self.base_bytesize, - self.base_parity, - self.base_stopbits, + params.host, + params.baudrate, + params.bytesize, + params.parity, + params.stopbits, ) - base.close = mock.Mock() mocker = mock.AsyncMock() - with mock.patch( "pymodbus.transport.transport.create_serial_connection", new=mocker ): mocker.side_effect = SerialException("testing") - assert await base.transport_listen() is None - base.close.assert_called_once() + assert await transport.transport_listen() is None mocker.side_effect = None - mocker.return_value = 117 - assert await base.transport_listen() == 117 - base.close.assert_called_once() + mocker.return_value = mock.Mock() + assert await transport.transport_listen() == mocker.return_value + transport.close() diff --git a/test/transport/test_comm.py b/test/transport/test_comm.py index e810fa332..5ea591916 100644 --- a/test/transport/test_comm.py +++ b/test/transport/test_comm.py @@ -1,382 +1,253 @@ """Test transport.""" -import asyncio -import os -import sys import time from tempfile import gettempdir import pytest -from pymodbus.framer import ModbusFramer, ModbusSocketFramer -from pymodbus.transport.transport import BaseTransport +BASE_PORT = 5200 -class TestCommTransport: - """Test for the transport module.""" - cwd = None - - @classmethod - def setup_CWD(cls): - """Get path to certificates.""" - cls.cwd = os.getcwd().split("/")[-1] - if cls.cwd == "transport": - cls.cwd = "../../" - elif cls.cwd == "test": - cls.cwd = "../" - else: - cls.cwd = "" - cls.cwd = cls.cwd + "examples/certificates/pymodbus." - - class dummy_transport(BaseTransport): - """Transport class for test.""" - - def cb_connection_made(self): - """Handle callback.""" - - def cb_connection_lost(self, _exc): - """Handle callback.""" - - def cb_handle_data(self, _data): - """Handle callback.""" - return 0 - - def __init__(self, framer: ModbusFramer, comm_name="test comm"): - """Initialize.""" - super().__init__( - comm_name, - [2500, 9000], - 2000, - framer, - self.cb_connection_made, - self.cb_connection_lost, - self.cb_handle_data, - ) - - @pytest.mark.skipif( - pytest.IS_WINDOWS, reason="Windows do not support unix sockets." +@pytest.fixture(name="domain_host") +def get_domain_host(positive): + """Get test host.""" + return "localhost" if positive else "/illegal_host_name" + + +@pytest.fixture(name="domain_socket") +def get_domain_socket(positive): + """Get test file.""" + return ( + gettempdir() + "/test_unix_" + str(time.time()) + if positive + else "/illegal_file_name" ) - @pytest.mark.xdist_group(name="server_serialize") - async def test_connect_unix(self): + + +@pytest.mark.skipif(pytest.IS_WINDOWS, reason="not implemented.") +class TestCommUnixTransport: + """Test for the transport module.""" + + @pytest.mark.parametrize("positive", [True, False]) + async def test_connect(self, transport, domain_socket): """Test connect_unix().""" - client = self.dummy_transport(ModbusSocketFramer) - domain_socket = "/domain_unix" - client.setup_unix(False, domain_socket) + transport.setup_unix(False, domain_socket) start = time.time() - assert await client.transport_connect() == (None, None) + assert await transport.transport_connect() == (None, None) delta = time.time() - start - assert delta < client.comm_params.timeout_connect * 1.2 + assert delta < transport.comm_params.timeout_connect * 1.2 + transport.close() + + @pytest.mark.parametrize("positive", [True, False]) + async def test_listen(self, transport_server, positive, domain_socket): + """Test listen_unix().""" + transport_server.setup_unix(True, domain_socket) + server = await transport_server.transport_listen() + assert positive == bool(server) + assert positive == bool(transport_server.transport) + if server: + server.close() + transport_server.close() + + @pytest.mark.parametrize("positive", [True]) + async def test_connected(self, transport, transport_server, domain_socket): + """Test listen/connect unix().""" + transport_server.setup_unix(True, domain_socket) + await transport_server.transport_listen() + + transport.setup_unix(False, domain_socket) + assert await transport.transport_connect() != (None, None) + transport.close() + transport_server.close() - client = self.dummy_transport(ModbusSocketFramer) - domain_socket = gettempdir() + "/domain_unix" - client.setup_unix(False, domain_socket) - start = time.time() - assert await client.transport_connect() == (None, None) - delta = time.time() - start - assert delta < client.comm_params.timeout_connect * 1.2 - @pytest.mark.xdist_group(name="server_serialize") - async def test_connect_tcp(self): +class TestCommTcpTransport: + """Test for the transport module.""" + + @pytest.mark.parametrize("positive", [True, False]) + async def test_connect(self, transport, domain_host): """Test connect_tcp().""" - client = self.dummy_transport(ModbusSocketFramer) - client.setup_tcp(False, "142.250.200.78", 502) + transport.setup_tcp(False, domain_host, BASE_PORT + 1) start = time.time() - assert await client.transport_connect() == (None, None) + assert await transport.transport_connect() == (None, None) delta = time.time() - start - assert delta < client.comm_params.timeout_connect * 1.2 + assert delta < transport.comm_params.timeout_connect * 1.2 + transport.close() - client = self.dummy_transport(ModbusSocketFramer) - client.setup_tcp(False, "localhost", 5001) - start = time.time() - assert await client.transport_connect() == (None, None) - delta = time.time() - start - assert delta < client.comm_params.timeout_connect * 1.2 + @pytest.mark.parametrize("positive", [True, False]) + async def test_listen(self, transport_server, positive, domain_host): + """Test listen_tcp().""" + transport_server.setup_tcp(True, domain_host, BASE_PORT + 2) + server = await transport_server.transport_listen() + assert positive == bool(server) + assert positive == bool(transport_server.transport) + transport_server.close() + if server: + server.close() + + @pytest.mark.parametrize("positive", [True]) + async def test_connected(self, transport, transport_server, domain_host): + """Test listen/connect tcp().""" + transport_server.setup_tcp(True, domain_host, BASE_PORT + 3) + server = await transport_server.transport_listen() + assert server + transport.setup_tcp(False, domain_host, BASE_PORT + 3) + assert await transport.transport_connect() != (None, None) + transport.close() + transport_server.close() + server.close() - @pytest.mark.xdist_group(name="server_serialize") - async def test_connect_tls(self): - """Test connect_tls().""" - self.setup_CWD() - client = self.dummy_transport(ModbusSocketFramer) - client.setup_tls( - False, - "142.250.200.78", - 502, - None, - self.cwd + "crt", - self.cwd + "key", - None, - "localhost", - ) - start = time.time() - assert await client.transport_connect() == (None, None) - delta = time.time() - start - assert delta < client.comm_params.timeout_connect * 1.2 - client = self.dummy_transport(ModbusSocketFramer) - client.setup_tls( +class TestCommTlsTransport: + """Test for the transport module.""" + + @pytest.mark.parametrize("positive", [True, False]) + async def test_connect(self, transport, params, domain_host): + """Test connect_tls().""" + transport.setup_tls( False, - "127.0.0.1", - 5001, + domain_host, + BASE_PORT + 5, None, - self.cwd + "crt", - self.cwd + "key", + params.cwd + "crt", + params.cwd + "key", None, "localhost", ) start = time.time() - assert await client.transport_connect() == (None, None) - delta = time.time() - start - assert delta < client.comm_params.timeout_connect * 1.2 - - @pytest.mark.xdist_group(name="server_serialize") - async def test_connect_serial(self): - """Test connect_serial().""" - client = self.dummy_transport(ModbusSocketFramer) - client.setup_serial( - False, - "no_port", - 9600, - 8, - "E", - 2, - ) - start = time.time() - assert await client.transport_connect() == (None, None) - delta = time.time() - start - assert delta < client.comm_params.timeout_connect * 1.2 - - client = self.dummy_transport(ModbusSocketFramer) - client.setup_serial( - False, - "unix:/localhost:5001", - 9600, - 8, - "E", - 2, - ) - start = time.time() - assert await client.transport_connect() == (None, None) + assert await transport.transport_connect() == (None, None) delta = time.time() - start - assert delta < client.comm_params.timeout_connect * 1.2 + assert delta < transport.comm_params.timeout_connect * 1.2 + transport.close() - @pytest.mark.skipif( - pytest.IS_WINDOWS, reason="Windows do not support unix sockets." - ) - @pytest.mark.xdist_group(name="server_serialize") - async def test_listen_unix(self): - """Test listen_unix().""" - server = self.dummy_transport(ModbusSocketFramer) - domain_socket = "/test_unix_" - server.setup_unix(True, domain_socket) - assert not await server.transport_listen() - assert not server.transport - - server = self.dummy_transport(ModbusSocketFramer) - domain_socket = gettempdir() + "/test_unix_" + str(time.time()) - server.setup_unix(True, domain_socket) - assert await server.transport_listen() - assert server.transport - server.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_listen_tcp(self): - """Test listen_tcp().""" - server = self.dummy_transport(ModbusSocketFramer) - server.setup_tcp(True, "10.0.0.1", 5101) - assert not await server.transport_listen() - assert not server.transport - - server = self.dummy_transport(ModbusSocketFramer) - server.setup_tcp(True, "localhost", 5101) - assert await server.transport_listen() - assert server.transport - server.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_listen_tls(self): + @pytest.mark.parametrize("positive", [True, False]) + async def test_listen(self, transport_server, params, positive, domain_host): """Test listen_tls().""" - self.setup_CWD() - server = self.dummy_transport(ModbusSocketFramer) - server.setup_tls( + transport_server.setup_tls( True, - "10.0.0.1", - 5101, + domain_host, + BASE_PORT + 6, None, - self.cwd + "crt", - self.cwd + "key", + params.cwd + "crt", + params.cwd + "key", None, "localhost", ) - assert not await server.transport_listen() - assert not server.transport - - server = self.dummy_transport(ModbusSocketFramer) - server.setup_tls( + server = await transport_server.transport_listen() + assert positive == bool(server) + assert positive == bool(transport_server.transport) + transport_server.close() + if server: + server.close() + + @pytest.mark.parametrize("positive", [True]) + async def test_connected(self, transport, transport_server, params, domain_host): + """Test listen/connect tls().""" + transport_server.setup_tls( True, - "127.0.0.1", - 5101, + domain_host, + BASE_PORT + 7, None, - self.cwd + "crt", - self.cwd + "key", + params.cwd + "crt", + params.cwd + "key", None, "localhost", ) - assert await server.transport_listen() - assert server.transport + server = await transport_server.transport_listen() + assert server + + transport.setup_tcp(False, domain_host, BASE_PORT + 7) + assert await transport.transport_connect() != (None, None) + transport.close() + transport_server.close() server.close() - @pytest.mark.xdist_group(name="server_serialize") - async def test_listen_udp(self): + +class TestCommUdpTransport: + """Test for the transport module.""" + + async def test_connect(self): + """Test connect_udp().""" + # always true, since udp is connectionless. + + @pytest.mark.parametrize("positive", [True, False]) + async def test_listen(self, transport_server, positive, domain_host): """Test listen_udp().""" - server = self.dummy_transport(ModbusSocketFramer) - server.setup_udp(True, "10.0.0.1", 5101) - assert not await server.transport_listen() - assert not server.transport - - server = self.dummy_transport(ModbusSocketFramer) - server.setup_udp(True, "localhost", 5101) - assert await server.transport_listen() - assert server.transport + transport_server.setup_udp(True, domain_host, BASE_PORT + 10) + server = await transport_server.transport_listen() + assert positive == bool(server) + assert positive == bool(transport_server.transport) + transport_server.close() + if server: + server.close() + + @pytest.mark.parametrize("positive", [True]) + async def test_connected(self, transport, transport_server, domain_host): + """Test listen/connect udp().""" + transport_server.setup_udp(True, domain_host, BASE_PORT + 11) + server = await transport_server.transport_listen() + assert server + transport.setup_udp(False, domain_host, BASE_PORT + 11) + assert await transport.transport_connect() != (None, None) + transport.close() + transport_server.close() server.close() - @pytest.mark.xdist_group(name="server_serialize") - async def test_listen_serial(self): - """Test listen_serial().""" - server = self.dummy_transport(ModbusSocketFramer) - server.setup_serial( - True, - "no port", + +class TestCommSerialTransport: + """Test for the transport module.""" + + @pytest.mark.parametrize("positive", [True, False]) + async def test_connect(self, transport, positive): + """Test connect_serial().""" + domain_port = ( + f"unix:/localhost:{BASE_PORT + 15}" if positive else "/illegal_port" + ) + transport.setup_serial( + False, + domain_port, 9600, 8, "E", 2, ) - assert not await server.transport_listen() - assert not server.transport - - # there are no positive test, since there are no standard tty port + start = time.time() + assert await transport.transport_connect() == (None, None) + delta = time.time() - start + assert delta < transport.comm_params.timeout_connect * 1.2 + transport.close() - @pytest.mark.skipif( - pytest.IS_WINDOWS, reason="Windows do not support unix sockets." - ) - @pytest.mark.xdist_group(name="server_serialize") - async def test_connected_unix(self): - """Test listen/connect unix().""" - server_protocol = self.dummy_transport(ModbusSocketFramer) - domain_socket = gettempdir() + "/test_unix_" + str(time.time()) - server_protocol.setup_unix(True, domain_socket) - await server_protocol.transport_listen() - - client = self.dummy_transport(ModbusSocketFramer) - client.setup_unix(False, domain_socket) - assert await client.transport_connect() != (None, None) - client.close() - server_protocol.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_connected_tcp(self): - """Test listen/connect tcp().""" - server_protocol = self.dummy_transport(ModbusSocketFramer) - server_protocol.setup_tcp(True, "localhost", 5101) - assert await server_protocol.transport_listen() - - client = self.dummy_transport(ModbusSocketFramer) - client.setup_tcp(False, "localhost", 5101) - assert await client.transport_connect() != (None, None) - client.close() - server_protocol.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_connected_tls(self): - """Test listen/connect tls().""" - self.setup_CWD() - server_protocol = self.dummy_transport(ModbusSocketFramer) - server_protocol.setup_tls( + async def test_listen(self, transport_server): + """Test listen_serial().""" + transport_server.setup_serial( True, - "127.0.0.1", - 5102, - None, - self.cwd + "crt", - self.cwd + "key", - None, - "localhost", + "/illegal_port", + 9600, + 8, + "E", + 2, ) - assert await server_protocol.transport_listen() + server = await transport_server.transport_listen() + assert not server + assert not transport_server.transport + transport_server.close() - client = self.dummy_transport(ModbusSocketFramer) - client.setup_tls( - False, - "127.0.0.1", - 5102, - None, - self.cwd + "crt", - self.cwd + "key", - None, - "localhost", - ) - assert await client.transport_connect() != (None, None) - client.close() - server_protocol.close() + # there are no positive test, since there are no standard tty port - @pytest.mark.xdist_group(name="server_serialize") - async def test_connected_udp(self): - """Test listen/connect udp().""" - server_protocol = self.dummy_transport(ModbusSocketFramer) - server_protocol.setup_udp(True, "localhost", 5101) - transport = await server_protocol.transport_listen() - assert transport - - client = self.dummy_transport(ModbusSocketFramer) - client.setup_udp(False, "localhost", 5101) - assert await client.transport_connect() != (None, None) - client.close() - server_protocol.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_connected_serial(self): + async def test_connected(self, transport, transport_server): """Test listen/connect serial().""" - server_protocol = self.dummy_transport(ModbusSocketFramer) - server_protocol.setup_tcp(True, "localhost", 5101) - assert await server_protocol.transport_listen() - - client = self.dummy_transport(ModbusSocketFramer) - client.setup_serial( + transport_server.setup_tcp(True, "localhost", BASE_PORT + 16) + server = await transport_server.transport_listen() + assert server + transport.setup_serial( False, - "unix:localhost:5001", + f"socket://localhost:{BASE_PORT + 16}", 9600, 8, "E", 2, ) - assert await client.transport_connect() == (None, None) - client.close() - server_protocol.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_connect_reconnect(self): - """Test connect() reconnecting.""" - server = self.dummy_transport(ModbusSocketFramer, comm_name="server mode") - server.setup_tcp(True, "localhost", 5101) - await server.transport_listen() - assert server.transport - - client = self.dummy_transport(ModbusSocketFramer, comm_name="client mode") - client.setup_tcp(False, "localhost", 5101) - assert await client.transport_connect() != (None, None) - server.close() - count = 100 - while client.transport and count: - await asyncio.sleep(0.1) - count -= 1 - if not sys.platform.startswith("win"): - assert not client.transport - assert client.reconnect_timer - assert ( - client.reconnect_delay_current == 2 * client.comm_params.reconnect_delay - ) - await asyncio.sleep(client.reconnect_delay_current * 1.2) - assert client.transport - assert client.reconnect_timer - assert client.reconnect_delay_current == client.comm_params.reconnect_delay - client.close() + assert await transport.transport_connect() != (None, None) + transport.close() + transport_server.close() server.close() diff --git a/test/transport/test_data.py b/test/transport/test_data.py index 035e3afb9..a48333f26 100644 --- a/test/transport/test_data.py +++ b/test/transport/test_data.py @@ -1,55 +1,27 @@ """Test transport.""" import asyncio -import pytest -from pymodbus.framer import ModbusFramer, ModbusSocketFramer -from pymodbus.transport.transport import BaseTransport +BASE_PORT = 5260 -class TestDataTransport: +class TestDataTransport: # pylint: disable=too-few-public-methods """Test for the transport module.""" - class dummy_transport(BaseTransport): - """Transport class for test.""" + async def test_client_send(self, transport, transport_server): + """Test send().""" + transport_server.setup_tcp(True, "localhost", BASE_PORT + 1) + server = await transport_server.transport_listen() + assert transport_server.transport - def cb_connection_made(self): - """Handle callback.""" - - def cb_connection_lost(self, _exc): - """Handle callback.""" - - def cb_handle_data(self, _data): - """Handle callback.""" - return 0 - - def __init__(self, framer: ModbusFramer, comm_name="test comm"): - """Initialize.""" - super().__init__( - comm_name, - [2500, 9000], - 2000, - framer, - self.cb_connection_made, - self.cb_connection_lost, - self.cb_handle_data, - ) - - @pytest.mark.skipif(pytest.IS_WINDOWS, reason="Windows problem.") - @pytest.mark.xdist_group(name="server_serialize") - async def test_client_send(self): - """Test connect() reconnecting.""" - server = self.dummy_transport(ModbusSocketFramer, comm_name="server mode") - server.setup_tcp(True, "localhost", 5101) - await server.transport_listen() - assert server.transport - - client = self.dummy_transport(ModbusSocketFramer, comm_name="client mode") - client.setup_tcp(False, "localhost", 5101) - assert await client.transport_connect() != (None, None) - await client.send(b"ABC") + transport.setup_tcp(False, "localhost", BASE_PORT + 1) + assert await transport.transport_connect() != (None, None) + await transport.send(b"ABC") await asyncio.sleep(2) - assert server.recv_buffer == b"ABC" - await server.send(b"DEF") + assert transport_server.recv_buffer == b"ABC" + await transport_server.send(b"DEF") await asyncio.sleep(2) - assert client.recv_buffer == b"DEF" + assert transport.recv_buffer == b"DEF" + transport.close() + transport_server.close() + server.close() diff --git a/test/transport/test_reconnect.py b/test/transport/test_reconnect.py new file mode 100644 index 000000000..12943948b --- /dev/null +++ b/test/transport/test_reconnect.py @@ -0,0 +1,64 @@ +"""Test transport.""" +import asyncio +from unittest import mock + + +BASE_PORT = 5250 + + +class TestReconnectTransport: + """Test transport module, base part.""" + + async def test_no_reconnect_call(self, transport, commparams): + """Test connection_lost().""" + transport.setup_tcp(False, "localhost", BASE_PORT + 1) + transport.call_connect_listen = mock.AsyncMock(return_value=(None, None)) + transport.connection_made(mock.Mock()) + assert not transport.call_connect_listen.call_count + assert transport.reconnect_delay_current == commparams.reconnect_delay + transport.connection_lost(RuntimeError("Connection lost")) + assert not transport.call_connect_listen.call_count + assert transport.reconnect_delay_current == commparams.reconnect_delay + transport.close() + + async def test_reconnect_call(self, transport, commparams): + """Test connection_lost().""" + transport.setup_tcp(False, "localhost", BASE_PORT + 2) + transport.call_connect_listen = mock.AsyncMock(return_value=(None, None)) + transport.connection_made(mock.Mock()) + transport.connection_lost(RuntimeError("Connection lost")) + await asyncio.sleep(transport.reconnect_delay_current * 1.2) + assert transport.call_connect_listen.call_count == 1 + assert transport.reconnect_delay_current == commparams.reconnect_delay * 2 + transport.close() + + async def test_multi_reconnect_call(self, transport, commparams): + """Test connection_lost().""" + transport.setup_tcp(False, "localhost", BASE_PORT + 3) + transport.call_connect_listen = mock.AsyncMock(return_value=(None, None)) + transport.connection_made(mock.Mock()) + transport.connection_lost(RuntimeError("Connection lost")) + await asyncio.sleep(transport.reconnect_delay_current * 1.2) + assert transport.call_connect_listen.call_count == 1 + assert transport.reconnect_delay_current == commparams.reconnect_delay * 2 + await asyncio.sleep(transport.reconnect_delay_current * 1.2) + assert transport.call_connect_listen.call_count == 2 + assert transport.reconnect_delay_current == commparams.reconnect_delay_max + await asyncio.sleep(transport.reconnect_delay_current * 1.2) + assert transport.call_connect_listen.call_count == 3 + assert transport.reconnect_delay_current == commparams.reconnect_delay_max + transport.close() + + async def test_reconnect_call_ok(self, transport, commparams): + """Test connection_lost().""" + transport.setup_tcp(False, "localhost", BASE_PORT + 4) + transport.call_connect_listen = mock.AsyncMock( + return_value=(mock.Mock(), mock.Mock()) + ) + transport.connection_made(mock.Mock()) + transport.connection_lost(RuntimeError("Connection lost")) + await asyncio.sleep(transport.reconnect_delay_current * 1.2) + assert transport.call_connect_listen.call_count == 1 + assert transport.reconnect_delay_current == commparams.reconnect_delay * 2 + assert not transport.reconnect_task + transport.close()