From c6375c8066a7321bf472b0cdf012a55e704bf55c Mon Sep 17 00:00:00 2001 From: jan iversen Date: Mon, 5 Jun 2023 20:18:18 +0200 Subject: [PATCH] transport fixes and 100% test coverage. (#1580) --- pymodbus/client/base.py | 10 +- pymodbus/transport/__init__.py | 6 - pymodbus/transport/transport.py | 43 +- test/transport/__init__.py | 1 + test/transport/conftest.py | 88 ++++ test/transport/test_basic.py | 759 ++++++++++++++----------------- test/transport/test_comm.py | 504 ++++++++------------ test/transport/test_data.py | 27 ++ test/transport/test_reconnect.py | 103 +---- test/transport/xtest_data.py | 55 --- 10 files changed, 691 insertions(+), 905 deletions(-) create mode 100644 test/transport/__init__.py create mode 100644 test/transport/conftest.py create mode 100644 test/transport/test_data.py delete mode 100644 test/transport/xtest_data.py diff --git a/pymodbus/client/base.py b/pymodbus/client/base.py index 39c3ca101..030cd41cd 100644 --- a/pymodbus/client/base.py +++ b/pymodbus/client/base.py @@ -14,11 +14,11 @@ from pymodbus.logging import Log from pymodbus.pdu import ModbusRequest, ModbusResponse from pymodbus.transaction import DictTransactionManager -from pymodbus.transport import BaseTransport +from pymodbus.transport.transport import Transport from pymodbus.utilities import ModbusTransactionState -class ModbusBaseClient(ModbusClientMixin, BaseTransport): +class ModbusBaseClient(ModbusClientMixin, Transport): """**ModbusBaseClient** **Parameters common to all clients**: @@ -94,12 +94,12 @@ def __init__( # pylint: disable=too-many-arguments **kwargs: Any, ) -> None: """Initialize a client instance.""" - BaseTransport.__init__( + Transport.__init__( self, "comm", - (reconnect_delay * 1000, reconnect_delay_max * 1000), + reconnect_delay * 1000, + reconnect_delay_max * 1000, timeout * 1000, - framer, lambda: None, self.cb_base_connection_lost, self.cb_base_handle_data, diff --git a/pymodbus/transport/__init__.py b/pymodbus/transport/__init__.py index 2d5c29eaa..d96b47771 100644 --- a/pymodbus/transport/__init__.py +++ b/pymodbus/transport/__init__.py @@ -1,7 +1 @@ """Transport.""" - -__all__ = [ - "BaseTransport", -] - -from pymodbus.transport.transport import BaseTransport diff --git a/pymodbus/transport/transport.py b/pymodbus/transport/transport.py index cbdf5a504..dddf70ca8 100644 --- a/pymodbus/transport/transport.py +++ b/pymodbus/transport/transport.py @@ -10,17 +10,22 @@ from dataclasses import dataclass from typing import Any, Callable, Coroutine -from pymodbus.framer import ModbusFramer from pymodbus.logging import Log from pymodbus.transport.serial_asyncio import create_serial_connection -class BaseTransport: - """Base class for transport types. +class Transport: + """Transport layer. - BaseTransport contains functions common to all transport types and client/server. + Contains pure transport methods needed to connect/listen, send/receive and close connections + for unix socket, tcp, tls and serial communications. - This class is not available in the pymodbus API, and should not be referenced in Applications. + Contains high level methods like reconnect. + + This class is not available in the pymodbus API, and should not be referenced in Applications + nor in the pymodbus documentation. + + The class is designed to be an object in the message level class. """ @dataclass @@ -33,7 +38,6 @@ class CommParamsClass: reconnect_delay: float = None reconnect_delay_max: float = None timeout_connect: float = None - framer: ModbusFramer = None # tcp / tls / udp / serial host: str = None @@ -60,9 +64,9 @@ def check_done(self): def __init__( self, comm_name: str, - reconnect_delay: tuple[int, int], + reconnect_delay: int, + reconnect_max: int, timeout_connect: int, - framer: ModbusFramer, callback_connected: Callable[[], None], callback_disconnected: Callable[[Exception], None], callback_data: Callable[[bytes], int], @@ -70,9 +74,9 @@ def __init__( """Initialize a transport instance. :param comm_name: name of this transport connection - :param reconnect_delay: delay and max in milliseconds for first reconnect (0,0 for no reconnect) + :param reconnect_delay: delay in milliseconds for first reconnect (0 for no reconnect) + :param reconnect_delay: max reconnect delay in milliseconds :param timeout_connect: Max. time in milliseconds for connect to complete - :param framer: Modbus framer to decode/encode messagees. :param callback_connected: Called when connection is established :param callback_disconnected: Called when connection is disconnected :param callback_data: Called when data is received @@ -84,19 +88,18 @@ def __init__( # properties, can be read, but may not be mingled with self.comm_params = self.CommParamsClass( comm_name=comm_name, - reconnect_delay=reconnect_delay[0] / 1000, - reconnect_delay_max=reconnect_delay[1] / 1000, + reconnect_delay=reconnect_delay / 1000, + reconnect_delay_max=reconnect_max / 1000, timeout_connect=timeout_connect / 1000, - framer=framer, ) - self.reconnect_delay_current: float = 0 + self.reconnect_delay_current: float = 0.0 self.transport: asyncio.BaseTransport | asyncio.Server = None self.protocol: asyncio.BaseProtocol = None self.loop: asyncio.AbstractEventLoop = None with suppress(RuntimeError): self.loop = asyncio.get_running_loop() - self.reconnect_timer: asyncio.Task = None + self.reconnect_task: asyncio.Task = None self.recv_buffer: bytes = b"" self.call_connect_listen: Callable[[], Coroutine[Any, Any, Any]] = lambda: None self.use_udp = False @@ -314,7 +317,7 @@ def connection_lost(self, reason: Exception): self.cb_connection_lost(reason) if self.transport: self.close() - self.reconnect_timer = asyncio.create_task(self.reconnect_connect()) + self.reconnect_task = asyncio.create_task(self.reconnect_connect()) def eof_received(self): """Call when eof received (other end closed connection). @@ -360,9 +363,9 @@ def close(self, reconnect: bool = False) -> None: self.transport.close() self.transport = None self.protocol = None - if not reconnect and self.reconnect_timer: - self.reconnect_timer.cancel() - self.reconnect_timer = None + if not reconnect and self.reconnect_task: + self.reconnect_task.cancel() + self.reconnect_task = None self.recv_buffer = b"" def reset_delay(self) -> None: @@ -395,7 +398,7 @@ async def reconnect_connect(self): ) except asyncio.CancelledError: pass - self.reconnect_timer = None + self.reconnect_task = None # ----------------- # # The magic methods # diff --git a/test/transport/__init__.py b/test/transport/__init__.py new file mode 100644 index 000000000..430da4624 --- /dev/null +++ b/test/transport/__init__.py @@ -0,0 +1 @@ +"""Test of transport layer.""" diff --git a/test/transport/conftest.py b/test/transport/conftest.py new file mode 100644 index 000000000..4140fad59 --- /dev/null +++ b/test/transport/conftest.py @@ -0,0 +1,88 @@ +"""Test transport.""" +import os +from dataclasses import dataclass +from unittest import mock + +import pytest +import pytest_asyncio + +from pymodbus.transport.transport import Transport + + +@dataclass +class BaseParams(Transport.CommParamsClass): + """Base parameters for all transport testing.""" + + comm_name = "test comm" + reconnect_delay = 1000 + reconnect_delay_max = 3500 + timeout_connect = 2000 + host = "test host" + port = 502 + server_hostname = "server test host" + baudrate = 9600 + bytesize = 8 + parity = "e" + stopbits = 2 + cwd = os.path.dirname(__file__) + "/../../examples/certificates/pymodbus." + + +@pytest.fixture(name="params") +def prepare_baseparams(): + """Prepare BaseParams class.""" + return BaseParams + + +class DummySocket: # pylint: disable=too-few-public-methods + """Socket simulator for test.""" + + def __init__(self): + """Initialize.""" + self.close = mock.Mock() + self.abort = mock.Mock() + + +@pytest.fixture(name="dummy_socket") +def prepare_dummysocket(): + """Prepare dummy_socket class.""" + return DummySocket + + +@pytest.fixture(name="commparams") +def prepare_testparams(): + """Prepare CommParamsClass object.""" + return Transport.CommParamsClass( + done=True, + comm_name=BaseParams.comm_name, + reconnect_delay=BaseParams.reconnect_delay / 1000, + reconnect_delay_max=BaseParams.reconnect_delay_max / 1000, + timeout_connect=BaseParams.timeout_connect / 1000, + ) + + +@pytest.fixture(name="transport") +async def prepare_transport(): + """Prepare transport object.""" + return Transport( + BaseParams.comm_name, + BaseParams.reconnect_delay, + BaseParams.reconnect_delay_max, + BaseParams.timeout_connect, + mock.Mock(name="cb_connection_made"), + mock.Mock(name="cb_connection_lost"), + mock.Mock(name="cb_handle_data", return_value=0), + ) + + +@pytest_asyncio.fixture(name="transport_server") +async def prepare_transport_server(): + """Prepare transport object.""" + return Transport( + BaseParams.comm_name, + BaseParams.reconnect_delay, + BaseParams.reconnect_delay_max, + BaseParams.timeout_connect, + mock.Mock(name="cb_connection_made"), + mock.Mock(name="cb_connection_lost"), + mock.Mock(name="cb_handle_data", return_value=0), + ) diff --git a/test/transport/test_basic.py b/test/transport/test_basic.py index ac73857c5..ad8816eae 100644 --- a/test/transport/test_basic.py +++ b/test/transport/test_basic.py @@ -1,506 +1,423 @@ """Test transport.""" import asyncio -import os from unittest import mock import pytest from serial import SerialException -from pymodbus.framer import ModbusFramer -from pymodbus.transport.transport import BaseTransport - -class TestBaseTransport: +class TestBasicTransport: """Test transport module, base part.""" - base_comm_name = "test comm" - base_reconnect_delay = 1 - base_reconnect_delay_max = 3.5 - base_timeout_connect = 2 - base_framer = ModbusFramer - base_host = "test host" - base_port = 502 - base_server_hostname = "server test host" - base_baudrate = 9600 - base_bytesize = 8 - base_parity = "e" - base_stopbits = 2 - cwd = None - - class dummy_transport(BaseTransport): - """Transport class for test.""" - - def __init__(self): - """Initialize.""" - super().__init__( - TestBaseTransport.base_comm_name, - [ - TestBaseTransport.base_reconnect_delay * 1000, - TestBaseTransport.base_reconnect_delay_max * 1000, - ], - TestBaseTransport.base_timeout_connect * 1000, - TestBaseTransport.base_framer, - None, - None, - None, - ) - self.abort = mock.MagicMock() - self.close = mock.MagicMock() - - @classmethod - async def setup_BaseTransport(cls): - """Create base object.""" - base = BaseTransport( - cls.base_comm_name, - (cls.base_reconnect_delay * 1000, cls.base_reconnect_delay_max * 1000), - cls.base_timeout_connect * 1000, - cls.base_framer, - mock.MagicMock(), - mock.MagicMock(), - mock.MagicMock(), + async def test_init(self, transport, commparams): + """Test init()""" + commparams.done = False + assert transport.comm_params == commparams + assert ( + transport.cb_connection_made._extract_mock_name() # pylint: disable=protected-access + == "cb_connection_made" ) - params = base.CommParamsClass( - done=True, - comm_name=cls.base_comm_name, - reconnect_delay=cls.base_reconnect_delay, - reconnect_delay_max=cls.base_reconnect_delay_max, - timeout_connect=cls.base_timeout_connect, - framer=cls.base_framer, + assert ( + transport.cb_connection_lost._extract_mock_name() # pylint: disable=protected-access + == "cb_connection_lost" ) - cls.cwd = os.getcwd().split("/")[-1] - if cls.cwd == "transport": - cls.cwd = "../../" - elif cls.cwd == "test": - cls.cwd = "../" - else: - cls.cwd = "" - cls.cwd = cls.cwd + "examples/certificates/pymodbus." - return base, params - - async def test_init(self): - """Test init()""" - base, params = await self.setup_BaseTransport() - params.done = False - assert base.comm_params == params - - assert base.cb_connection_made - assert base.cb_connection_lost - assert base.cb_handle_data - assert not base.reconnect_delay_current - assert not base.reconnect_timer + assert ( + transport.cb_handle_data._extract_mock_name() # pylint: disable=protected-access + == "cb_handle_data" + ) + assert not transport.reconnect_delay_current + assert not transport.reconnect_task - async def test_property_done(self): + async def test_property_done(self, transport): """Test done property""" - base, params = await self.setup_BaseTransport() - base.comm_params.check_done() + transport.comm_params.check_done() with pytest.raises(RuntimeError): - base.comm_params.check_done() + transport.comm_params.check_done() - @pytest.mark.skipif( - pytest.IS_WINDOWS, reason="Windows do not support unix sockets." - ) - @pytest.mark.parametrize("setup_server", [True, False]) - async def test_properties_unix(self, setup_server): - """Test properties.""" - base, params = await self.setup_BaseTransport() - base.setup_unix(setup_server, self.base_host) - params.host = self.base_host - assert base.comm_params == params - assert base.call_connect_listen - - @pytest.mark.skipif( - not pytest.IS_WINDOWS, reason="Windows do not support unix sockets." - ) - @pytest.mark.parametrize("setup_server", [True, False]) - async def test_properties_unix_windows(self, setup_server): - """Test properties.""" - base, params = await self.setup_BaseTransport() - with pytest.raises(RuntimeError): - base.setup_unix(setup_server, self.base_host) - - @pytest.mark.parametrize("setup_server", [True, False]) - async def test_properties_tcp(self, setup_server): - """Test properties.""" - base, params = await self.setup_BaseTransport() - base.setup_tcp(setup_server, self.base_host, self.base_port) - params.host = self.base_host - params.port = self.base_port - assert base.comm_params == params - assert base.call_connect_listen - - @pytest.mark.parametrize("setup_server", [True, False]) - async def test_properties_udp(self, setup_server): - """Test properties.""" - base, params = await self.setup_BaseTransport() - base.setup_udp(setup_server, self.base_host, self.base_port) - params.host = self.base_host - params.port = self.base_port - assert base.comm_params == params - assert base.call_connect_listen - - @pytest.mark.parametrize("setup_server", [True, False]) - @pytest.mark.parametrize("sslctx", [None, "test ctx"]) - async def test_properties_tls(self, setup_server, sslctx): - """Test properties.""" - base, params = await self.setup_BaseTransport() - with mock.patch("pymodbus.transport.transport.ssl.SSLContext"): - base.setup_tls( - setup_server, - self.base_host, - self.base_port, - sslctx, - None, - None, - None, - self.base_server_hostname, - ) - params.host = self.base_host - params.port = self.base_port - params.server_hostname = self.base_server_hostname - params.ssl = sslctx if sslctx else base.comm_params.ssl - assert base.comm_params == params - assert base.call_connect_listen - - @pytest.mark.parametrize("setup_server", [True, False]) - async def test_properties_serial(self, setup_server): - """Test properties.""" - base, params = await self.setup_BaseTransport() - base.setup_serial( - setup_server, - self.base_host, - self.base_baudrate, - self.base_bytesize, - self.base_parity, - self.base_stopbits, - ) - params.host = self.base_host - params.baudrate = self.base_baudrate - params.bytesize = self.base_bytesize - params.parity = self.base_parity - params.stopbits = self.base_stopbits - assert base.comm_params == params - assert base.call_connect_listen - - async def test_with_magic(self): + async def test_with_magic(self, transport): """Test magic.""" - base, _params = await self.setup_BaseTransport() - base.close = mock.MagicMock() - async with base: + transport.close = mock.MagicMock() + async with transport: pass - base.close.assert_called_once() + transport.close.assert_called_once() - async def test_str_magic(self): + async def test_str_magic(self, params, transport): """Test magic.""" - base, _params = await self.setup_BaseTransport() - assert str(base) == f"BaseTransport({self.base_comm_name})" + assert str(transport) == f"Transport({params.comm_name})" - async def test_connection_made(self): + async def test_connection_made(self, dummy_socket, transport, commparams): """Test connection_made().""" - base, params = await self.setup_BaseTransport() - transport = self.dummy_transport() - base.connection_made(transport) - assert base.transport == transport - assert not base.recv_buffer - assert not base.reconnect_timer - assert base.reconnect_delay_current == params.reconnect_delay - base.cb_connection_made.assert_called_once() - base.cb_connection_lost.assert_not_called() - base.cb_handle_data.assert_not_called() - base.close() - - async def test_connection_lost(self): + transport.connection_made(dummy_socket()) + assert transport.transport + assert not transport.recv_buffer + assert not transport.reconnect_task + assert transport.reconnect_delay_current == commparams.reconnect_delay + transport.cb_connection_made.assert_called_once() + transport.cb_connection_lost.assert_not_called() + transport.cb_handle_data.assert_not_called() + transport.close() + + async def test_connection_lost(self, transport): """Test connection_lost().""" - base, params = await self.setup_BaseTransport() - transport = self.dummy_transport() - base.connection_made(transport) - base.cb_connection_made.reset_mock() - base.connection_lost(RuntimeError("not implemented")) - assert not base.transport - assert not base.recv_buffer - assert base.reconnect_timer - assert base.reconnect_delay_current - base.cb_connection_made.assert_not_called() - base.cb_handle_data.assert_not_called() - base.cb_connection_lost.assert_called_once() - # reconnect is only after a successful connect - # base.connection_made(transport) - # base.connection_lost(transport) - # assert base.reconnect_timer - # assert not base.transport - # assert not base.recv_buffer - # assert base.reconnect_timer - # assert base.reconnect_delay_current == 2 * params.reconnect_delay - # base.cb_connection_lost.call_count == 2 - # base.close() - # assert not base.reconnect_timer - - async def test_eof_received(self): + transport.connection_lost(RuntimeError("not implemented")) + assert not transport.transport + assert not transport.recv_buffer + assert not transport.reconnect_task + assert not transport.reconnect_delay_current + transport.cb_connection_made.assert_not_called() + transport.cb_handle_data.assert_not_called() + transport.cb_connection_lost.assert_called_once() + + transport.transport = mock.Mock() + transport.connection_lost(RuntimeError("not implemented")) + assert not transport.transport + assert transport.reconnect_task + transport.close() + assert not transport.reconnect_task + + async def test_eof_received(self, transport): """Test connection_lost().""" - base, params = await self.setup_BaseTransport() - self.dummy_transport() - base.eof_received() - assert not base.transport - assert not base.recv_buffer - assert not base.reconnect_timer - assert not base.reconnect_delay_current - - async def test_close(self): - """Test close().""" - base, _params = await self.setup_BaseTransport() - transport = self.dummy_transport() - base.connection_made(transport) - base.cb_connection_made.reset_mock() - base.cb_connection_lost.reset_mock() - base.cb_handle_data.reset_mock() - base.recv_buffer = b"abc" - base.reconnect_timer = mock.MagicMock() - base.close() - transport.abort.assert_called_once() - transport.close.assert_called_once() - base.cb_connection_made.assert_not_called() - base.cb_connection_lost.assert_not_called() - base.cb_handle_data.assert_not_called() - assert not base.recv_buffer - assert not base.reconnect_timer + transport.eof_received() + assert not transport.transport + assert not transport.recv_buffer + assert not transport.reconnect_task + assert not transport.reconnect_delay_current - async def test_reset_delay(self): + async def test_close(self, dummy_socket, transport): + """Test close().""" + socket = dummy_socket() + transport.connection_made(socket) + transport.cb_connection_made.reset_mock() + transport.cb_connection_lost.reset_mock() + transport.cb_handle_data.reset_mock() + transport.recv_buffer = b"abc" + transport.reconnect_task = mock.MagicMock() + transport.close() + socket.abort.assert_called_once() + socket.close.assert_called_once() + transport.cb_connection_made.assert_not_called() + transport.cb_connection_lost.assert_not_called() + transport.cb_handle_data.assert_not_called() + assert not transport.recv_buffer + assert not transport.reconnect_task + + async def test_reset_delay(self, transport, commparams): """Test reset_delay().""" - base, _params = await self.setup_BaseTransport() - base.reconnect_delay_current = self.base_reconnect_delay + 1 - base.reset_delay() - assert base.reconnect_delay_current == self.base_reconnect_delay + transport.reconnect_delay_current += 5.17 + transport.reset_delay() + assert transport.reconnect_delay_current == commparams.reconnect_delay - async def test_datagram(self): + async def test_datagram(self, transport): """Test datagram_received().""" - base, _params = await self.setup_BaseTransport() - base.data_received = mock.MagicMock() - base.datagram_received(b"abc", "127.0.0.1") - base.data_received.assert_called_once() + transport.data_received = mock.MagicMock() + transport.datagram_received(b"abc", "127.0.0.1") + transport.data_received.assert_called_once() - async def test_data(self): + async def test_data(self, transport): """Test data_received.""" - base, _params = await self.setup_BaseTransport() - base.cb_handle_data = mock.MagicMock(return_value=2) - base.data_received(b"123456") - base.cb_handle_data.assert_called_once() - assert base.recv_buffer == b"3456" - base.data_received(b"789") - assert base.recv_buffer == b"56789" - - async def test_send(self): + transport.cb_handle_data = mock.MagicMock(return_value=2) + transport.data_received(b"123456") + transport.cb_handle_data.assert_called_once() + assert transport.recv_buffer == b"3456" + transport.data_received(b"789") + assert transport.recv_buffer == b"56789" + + async def test_send(self, transport, params): """Test send().""" - base, _params = await self.setup_BaseTransport() - base.transport = mock.AsyncMock() - await base.send(b"abc") - - @pytest.mark.skipif( - pytest.IS_WINDOWS, reason="Windows do not support unix sockets." - ) - async def test_connect_unix(self): - """Test connect_unix().""" - base, _params = await self.setup_BaseTransport() - base.setup_unix(False, self.base_host) - base.close = mock.Mock() - mocker = mock.AsyncMock() + transport.transport = mock.AsyncMock() + await transport.send(b"abc") - base.loop.create_unix_connection = mocker - mocker.side_effect = FileNotFoundError("testing") - assert await base.transport_connect() == (None, None) - base.close.assert_called_once() - mocker.side_effect = None + transport.setup_udp(False, params.host, params.port) + await transport.send(b"abc") - mocker.return_value = (117, 118) - assert mocker.return_value == await base.transport_connect() - base.close.called_once() - - async def test_connect_tcp(self): - """Test connect_tcp().""" - base, _params = await self.setup_BaseTransport() - base.setup_tcp(False, self.base_host, self.base_port) - base.close = mock.Mock() - mocker = mock.AsyncMock() + async def test_handle_listen(self, transport): + """Test handle_listen().""" + assert transport == transport.handle_listen() - base.loop.create_connection = mocker - mocker.side_effect = asyncio.TimeoutError("testing") - assert await base.transport_connect() == (None, None) - base.close.assert_called_once() - mocker.side_effect = None + async def test_reconnect_connect(self, transport): + """Test handle_listen().""" + transport.comm_params.reconnect_delay = 0.01 + transport.transport_connect = mock.AsyncMock( + side_effect=[(None, None), (117, 118)] + ) + await transport.reconnect_connect() + assert ( + transport.reconnect_delay_current + == transport.comm_params.reconnect_delay * 4 + ) + assert not transport.reconnect_task + transport.transport_connect = mock.AsyncMock( + side_effect=asyncio.CancelledError("stop loop") + ) + await transport.reconnect_connect() + assert ( + transport.reconnect_delay_current == transport.comm_params.reconnect_delay + ) + assert not transport.reconnect_task - mocker.return_value = (117, 118) - assert mocker.return_value == await base.transport_connect() - base.close.assert_called_once() - async def test_connect_tls(self): - """Test connect_tcls().""" - base, _params = await self.setup_BaseTransport() - base.setup_tls( - False, - self.base_host, - self.base_port, - "no ssl", - None, - None, - None, - self.base_server_hostname, - ) - base.close = mock.Mock() - mocker = mock.AsyncMock() +@pytest.mark.skipif(pytest.IS_WINDOWS, reason="not implemented") +class TestBasicUnixTransport: + """Test transport module, unix part.""" - base.loop.create_connection = mocker - mocker.side_effect = asyncio.TimeoutError("testing") - assert await base.transport_connect() == (None, None) - base.close.assert_called_once() - mocker.side_effect = None + @pytest.mark.parametrize("setup_server", [True, False]) + def test_properties(self, params, setup_server, transport, commparams): + """Test properties.""" + transport.setup_unix(setup_server, params.host) + commparams.host = params.host + assert transport.comm_params == commparams + assert transport.call_connect_listen + transport.close() - mocker.return_value = (117, 118) - assert mocker.return_value == await base.transport_connect() - base.close.assert_called_once() + @pytest.mark.parametrize("setup_server", [True, False]) + def test_properties_windows(self, params, setup_server, transport): + """Test properties.""" + with mock.patch( + "pymodbus.transport.transport.sys.platform", return_value="windows" + ), pytest.raises(RuntimeError): + transport.setup_unix(setup_server, params.host) - async def test_connect_udp(self): - """Test connect_udp().""" - base, _params = await self.setup_BaseTransport() - base.setup_udp(False, self.base_host, self.base_port) - base.close = mock.Mock() + async def test_connect(self, params, transport): + """Test connect_unix().""" + transport.setup_unix(False, params.host) mocker = mock.AsyncMock() - - base.loop.create_datagram_endpoint = mocker - mocker.side_effect = asyncio.TimeoutError("testing") - assert await base.transport_connect() == (None, None) - base.close.assert_called_once() + transport.loop.create_unix_connection = mocker + mocker.side_effect = FileNotFoundError("testing") + assert await transport.transport_connect() == (None, None) mocker.side_effect = None - mocker.return_value = (117, 118) - assert mocker.return_value == await base.transport_connect() - base.close.assert_called_once() + mocker.return_value = (mock.Mock(), mock.Mock()) + assert mocker.return_value == await transport.transport_connect() + transport.close() - async def test_connect_serial(self): - """Test connect_serial().""" - base, _params = await self.setup_BaseTransport() - base.setup_serial( - False, - self.base_host, - self.base_baudrate, - self.base_bytesize, - self.base_parity, - self.base_stopbits, - ) - base.close = mock.Mock() + async def test_listen(self, params, transport): + """Test listen_unix().""" + transport.setup_unix(True, params.host) mocker = mock.AsyncMock() + transport.loop.create_unix_server = mocker + mocker.side_effect = OSError("testing") + assert await transport.transport_listen() is None + mocker.side_effect = None - with mock.patch( - "pymodbus.transport.transport.create_serial_connection", new=mocker - ): - mocker.side_effect = asyncio.TimeoutError("testing") - assert await base.transport_connect() == (None, None) - base.close.assert_called_once() - mocker.side_effect = None + mocker.return_value = mock.Mock() + assert mocker.return_value == await transport.transport_listen() + transport.close() - mocker.return_value = (117, 118) - assert mocker.return_value == await base.transport_connect() - base.close.assert_called_once() - @pytest.mark.skipif( - pytest.IS_WINDOWS, reason="Windows do not support unix sockets." - ) - async def test_listen_unix(self): - """Test listen_unix().""" - base, _params = await self.setup_BaseTransport() - base.setup_unix(True, self.base_host) - base.close = mock.Mock() - mocker = mock.AsyncMock() +class TestBasicTcpTransport: + """Test transport module, tcp part.""" - base.loop.create_unix_server = mocker - mocker.side_effect = OSError("testing") - assert await base.transport_listen() is None - base.close.assert_called_once() + @pytest.mark.parametrize("setup_server", [True, False]) + def test_properties(self, params, setup_server, transport, commparams): + """Test properties.""" + transport.setup_tcp(setup_server, params.host, params.port) + commparams.host = params.host + commparams.port = params.port + assert transport.comm_params == commparams + assert transport.call_connect_listen + transport.close() + + async def test_connect(self, params, transport): + """Test connect_tcp().""" + transport.setup_tcp(False, params.host, params.port) + mocker = mock.AsyncMock() + transport.loop.create_connection = mocker + mocker.side_effect = asyncio.TimeoutError("testing") + assert await transport.transport_connect() == (None, None) mocker.side_effect = None - mocker.return_value = 117 - assert mocker.return_value == await base.transport_listen() - base.close.assert_called_once() + mocker.return_value = (mock.Mock(), mock.Mock()) + assert mocker.return_value == await transport.transport_connect() + transport.close() - async def test_listen_tcp(self): + async def test_listen(self, params, transport): """Test listen_tcp().""" - base, _params = await self.setup_BaseTransport() - base.setup_tcp(True, self.base_host, self.base_port) - base.close = mock.Mock() + transport.setup_tcp(True, params.host, params.port) mocker = mock.AsyncMock() - - base.loop.create_server = mocker + transport.loop.create_server = mocker mocker.side_effect = OSError("testing") - assert await base.transport_listen() is None - base.close.assert_called_once() + assert await transport.transport_listen() is None mocker.side_effect = None - mocker.return_value = 117 - assert mocker.return_value == await base.transport_listen() - base.close.assert_called_once() + mocker.return_value = mock.Mock() + assert mocker.return_value == await transport.transport_listen() + transport.close() + + +class TestBasicTlsTransport: + """Test transport module, tls part.""" - async def test_listen_tls(self): + @pytest.mark.parametrize("setup_server", [True, False]) + @pytest.mark.parametrize("sslctx", [None, "test ctx"]) + def test_properties(self, setup_server, sslctx, params, transport, commparams): + """Test properties.""" + with mock.patch("pymodbus.transport.transport.ssl.SSLContext"): + transport.setup_tls( + setup_server, + params.host, + params.port, + sslctx, + "certfile dummy", + None, + None, + params.server_hostname, + ) + commparams.host = params.host + commparams.port = params.port + commparams.server_hostname = params.server_hostname + commparams.ssl = sslctx if sslctx else transport.comm_params.ssl + assert transport.comm_params == commparams + assert transport.call_connect_listen + transport.close() + + async def test_connect(self, params, transport): + """Test connect_tcls().""" + transport.setup_tls( + False, + params.host, + params.port, + "no ssl", + None, + None, + None, + params.server_hostname, + ) + mocker = mock.AsyncMock() + transport.loop.create_connection = mocker + mocker.side_effect = asyncio.TimeoutError("testing") + assert await transport.transport_connect() == (None, None) + mocker.side_effect = None + + mocker.return_value = (mock.Mock(), mock.Mock()) + assert mocker.return_value == await transport.transport_connect() + transport.close() + + async def test_listen(self, params, transport): """Test listen_tls().""" - base, _params = await self.setup_BaseTransport() - base.setup_tls( + transport.setup_tls( True, - self.base_host, - self.base_port, + params.host, + params.port, "no ssl", None, None, None, - self.base_server_hostname, + params.server_hostname, ) - base.close = mock.Mock() mocker = mock.AsyncMock() - - base.loop.create_server = mocker + transport.loop.create_server = mocker mocker.side_effect = OSError("testing") - assert await base.transport_listen() is None - base.close.assert_called_once() + assert await transport.transport_listen() is None mocker.side_effect = None - mocker.return_value = 117 - assert mocker.return_value == await base.transport_listen() - base.close.assert_called_once() + mocker.return_value = mock.Mock() + assert mocker.return_value == await transport.transport_listen() + transport.close() - async def test_listen_udp(self): - """Test listen_udp().""" - base, _params = await self.setup_BaseTransport() - base.setup_udp(True, self.base_host, self.base_port) - base.close = mock.Mock() + +class TestBasicUdpTransport: + """Test transport module, udp part.""" + + @pytest.mark.parametrize("setup_server", [True, False]) + def test_properties(self, params, setup_server, transport, commparams): + """Test properties.""" + transport.setup_udp(setup_server, params.host, params.port) + commparams.host = params.host + commparams.port = params.port + assert transport.comm_params == commparams + assert transport.call_connect_listen + transport.close() + + async def test_connect(self, params, transport): + """Test connect_udp().""" + transport.setup_udp(False, params.host, params.port) mocker = mock.AsyncMock() + transport.loop.create_datagram_endpoint = mocker + mocker.side_effect = asyncio.TimeoutError("testing") + assert await transport.transport_connect() == (None, None) + mocker.side_effect = None + + mocker.return_value = (mock.Mock(), mock.Mock()) + assert mocker.return_value == await transport.transport_connect() + transport.close() - base.loop.create_datagram_endpoint = mocker + async def test_listen(self, params, transport): + """Test listen_udp().""" + transport.setup_udp(True, params.host, params.port) + mocker = mock.AsyncMock() + transport.loop.create_datagram_endpoint = mocker mocker.side_effect = OSError("testing") - assert await base.transport_listen() is None - base.close.assert_called_once() + assert await transport.transport_listen() is None mocker.side_effect = None - mocker.return_value = (117, 118) - assert await base.transport_listen() == 117 - base.close.assert_called_once() + mocker.return_value = (mock.Mock(), mock.Mock()) + assert await transport.transport_listen() == mocker.return_value[0] + transport.close() - async def test_listen_serial(self): + +class TestBasicSerialTransport: + """Test transport module, serial part.""" + + @pytest.mark.parametrize("setup_server", [True, False]) + def test_properties(self, params, setup_server, transport, commparams): + """Test properties.""" + transport.setup_serial( + setup_server, + params.host, + params.baudrate, + params.bytesize, + params.parity, + params.stopbits, + ) + commparams.host = params.host + commparams.baudrate = params.baudrate + commparams.bytesize = params.bytesize + commparams.parity = params.parity + commparams.stopbits = params.stopbits + assert transport.comm_params == commparams + assert transport.call_connect_listen + transport.close() + + async def test_connect(self, params, transport): + """Test connect_serial().""" + transport.setup_serial( + False, + params.host, + params.baudrate, + params.bytesize, + params.parity, + params.stopbits, + ) + mocker = mock.AsyncMock() + with mock.patch( + "pymodbus.transport.transport.create_serial_connection", new=mocker + ): + mocker.side_effect = asyncio.TimeoutError("testing") + assert await transport.transport_connect() == (None, None) + mocker.side_effect = None + + mocker.return_value = (mock.Mock(), mock.Mock()) + assert mocker.return_value == await transport.transport_connect() + transport.close() + + async def test_listen(self, params, transport): """Test listen_serial().""" - base, _params = await self.setup_BaseTransport() - base.setup_serial( + transport.setup_serial( True, - self.base_host, - self.base_baudrate, - self.base_bytesize, - self.base_parity, - self.base_stopbits, + params.host, + params.baudrate, + params.bytesize, + params.parity, + params.stopbits, ) - base.close = mock.Mock() mocker = mock.AsyncMock() - with mock.patch( "pymodbus.transport.transport.create_serial_connection", new=mocker ): mocker.side_effect = SerialException("testing") - assert await base.transport_listen() is None - base.close.assert_called_once() + assert await transport.transport_listen() is None mocker.side_effect = None - mocker.return_value = 117 - assert await base.transport_listen() == 117 - base.close.assert_called_once() + mocker.return_value = mock.Mock() + assert await transport.transport_listen() == mocker.return_value + transport.close() diff --git a/test/transport/test_comm.py b/test/transport/test_comm.py index 7855139df..5ea591916 100644 --- a/test/transport/test_comm.py +++ b/test/transport/test_comm.py @@ -1,377 +1,253 @@ """Test transport.""" -import asyncio -import os import time from tempfile import gettempdir import pytest -from pymodbus.framer import ModbusFramer, ModbusSocketFramer -from pymodbus.transport.transport import BaseTransport +BASE_PORT = 5200 -class TestCommTransport: - """Test for the transport module.""" - cwd = None - - @classmethod - def setup_CWD(cls): - """Get path to certificates.""" - cls.cwd = os.getcwd().split("/")[-1] - if cls.cwd == "transport": - cls.cwd = "../../" - elif cls.cwd == "test": - cls.cwd = "../" - else: - cls.cwd = "" - cls.cwd = cls.cwd + "examples/certificates/pymodbus." - - class dummy_transport(BaseTransport): - """Transport class for test.""" - - def cb_connection_made(self): - """Handle callback.""" - - def cb_connection_lost(self, _exc): - """Handle callback.""" - - def cb_handle_data(self, _data): - """Handle callback.""" - return 0 - - def __init__(self, framer: ModbusFramer, comm_name="test comm"): - """Initialize.""" - super().__init__( - comm_name, - [2500, 9000], - 2000, - framer, - self.cb_connection_made, - self.cb_connection_lost, - self.cb_handle_data, - ) - - @pytest.mark.skipif( - pytest.IS_WINDOWS, reason="Windows do not support unix sockets." +@pytest.fixture(name="domain_host") +def get_domain_host(positive): + """Get test host.""" + return "localhost" if positive else "/illegal_host_name" + + +@pytest.fixture(name="domain_socket") +def get_domain_socket(positive): + """Get test file.""" + return ( + gettempdir() + "/test_unix_" + str(time.time()) + if positive + else "/illegal_file_name" ) - @pytest.mark.xdist_group(name="server_serialize") - async def test_connect_unix(self): + + +@pytest.mark.skipif(pytest.IS_WINDOWS, reason="not implemented.") +class TestCommUnixTransport: + """Test for the transport module.""" + + @pytest.mark.parametrize("positive", [True, False]) + async def test_connect(self, transport, domain_socket): """Test connect_unix().""" - client = self.dummy_transport(ModbusSocketFramer) - domain_socket = "/domain_unix" - client.setup_unix(False, domain_socket) + transport.setup_unix(False, domain_socket) start = time.time() - assert await client.transport_connect() == (None, None) + assert await transport.transport_connect() == (None, None) delta = time.time() - start - assert delta < client.comm_params.timeout_connect * 1.2 + assert delta < transport.comm_params.timeout_connect * 1.2 + transport.close() + + @pytest.mark.parametrize("positive", [True, False]) + async def test_listen(self, transport_server, positive, domain_socket): + """Test listen_unix().""" + transport_server.setup_unix(True, domain_socket) + server = await transport_server.transport_listen() + assert positive == bool(server) + assert positive == bool(transport_server.transport) + if server: + server.close() + transport_server.close() + + @pytest.mark.parametrize("positive", [True]) + async def test_connected(self, transport, transport_server, domain_socket): + """Test listen/connect unix().""" + transport_server.setup_unix(True, domain_socket) + await transport_server.transport_listen() + + transport.setup_unix(False, domain_socket) + assert await transport.transport_connect() != (None, None) + transport.close() + transport_server.close() - client = self.dummy_transport(ModbusSocketFramer) - domain_socket = gettempdir() + "/domain_unix" - client.setup_unix(False, domain_socket) - start = time.time() - assert await client.transport_connect() == (None, None) - delta = time.time() - start - assert delta < client.comm_params.timeout_connect * 1.2 - @pytest.mark.xdist_group(name="server_serialize") - async def test_connect_tcp(self): +class TestCommTcpTransport: + """Test for the transport module.""" + + @pytest.mark.parametrize("positive", [True, False]) + async def test_connect(self, transport, domain_host): """Test connect_tcp().""" - client = self.dummy_transport(ModbusSocketFramer) - client.setup_tcp(False, "142.250.200.78", 502) + transport.setup_tcp(False, domain_host, BASE_PORT + 1) start = time.time() - assert await client.transport_connect() == (None, None) + assert await transport.transport_connect() == (None, None) delta = time.time() - start - assert delta < client.comm_params.timeout_connect * 1.2 + assert delta < transport.comm_params.timeout_connect * 1.2 + transport.close() - client = self.dummy_transport(ModbusSocketFramer) - client.setup_tcp(False, "localhost", 5001) - start = time.time() - assert await client.transport_connect() == (None, None) - delta = time.time() - start - assert delta < client.comm_params.timeout_connect * 1.2 + @pytest.mark.parametrize("positive", [True, False]) + async def test_listen(self, transport_server, positive, domain_host): + """Test listen_tcp().""" + transport_server.setup_tcp(True, domain_host, BASE_PORT + 2) + server = await transport_server.transport_listen() + assert positive == bool(server) + assert positive == bool(transport_server.transport) + transport_server.close() + if server: + server.close() + + @pytest.mark.parametrize("positive", [True]) + async def test_connected(self, transport, transport_server, domain_host): + """Test listen/connect tcp().""" + transport_server.setup_tcp(True, domain_host, BASE_PORT + 3) + server = await transport_server.transport_listen() + assert server + transport.setup_tcp(False, domain_host, BASE_PORT + 3) + assert await transport.transport_connect() != (None, None) + transport.close() + transport_server.close() + server.close() - @pytest.mark.xdist_group(name="server_serialize") - async def test_connect_tls(self): - """Test connect_tls().""" - self.setup_CWD() - client = self.dummy_transport(ModbusSocketFramer) - client.setup_tls( - False, - "142.250.200.78", - 502, - None, - self.cwd + "crt", - self.cwd + "key", - None, - "localhost", - ) - start = time.time() - assert await client.transport_connect() == (None, None) - delta = time.time() - start - assert delta < client.comm_params.timeout_connect * 1.2 - client = self.dummy_transport(ModbusSocketFramer) - client.setup_tls( +class TestCommTlsTransport: + """Test for the transport module.""" + + @pytest.mark.parametrize("positive", [True, False]) + async def test_connect(self, transport, params, domain_host): + """Test connect_tls().""" + transport.setup_tls( False, - "127.0.0.1", - 5001, + domain_host, + BASE_PORT + 5, None, - self.cwd + "crt", - self.cwd + "key", + params.cwd + "crt", + params.cwd + "key", None, "localhost", ) start = time.time() - assert await client.transport_connect() == (None, None) - delta = time.time() - start - assert delta < client.comm_params.timeout_connect * 1.2 - - @pytest.mark.xdist_group(name="server_serialize") - async def test_connect_serial(self): - """Test connect_serial().""" - client = self.dummy_transport(ModbusSocketFramer) - client.setup_serial( - False, - "no_port", - 9600, - 8, - "E", - 2, - ) - start = time.time() - assert await client.transport_connect() == (None, None) - delta = time.time() - start - assert delta < client.comm_params.timeout_connect * 1.2 - - client = self.dummy_transport(ModbusSocketFramer) - client.setup_serial( - False, - "unix:/localhost:5001", - 9600, - 8, - "E", - 2, - ) - start = time.time() - assert await client.transport_connect() == (None, None) + assert await transport.transport_connect() == (None, None) delta = time.time() - start - assert delta < client.comm_params.timeout_connect * 1.2 + assert delta < transport.comm_params.timeout_connect * 1.2 + transport.close() - @pytest.mark.skipif( - pytest.IS_WINDOWS, reason="Windows do not support unix sockets." - ) - @pytest.mark.xdist_group(name="server_serialize") - async def test_listen_unix(self): - """Test listen_unix().""" - server = self.dummy_transport(ModbusSocketFramer) - domain_socket = "/test_unix_" - server.setup_unix(True, domain_socket) - assert not await server.transport_listen() - assert not server.transport - - server = self.dummy_transport(ModbusSocketFramer) - domain_socket = gettempdir() + "/test_unix_" + str(time.time()) - server.setup_unix(True, domain_socket) - assert await server.transport_listen() - assert server.transport - server.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_listen_tcp(self): - """Test listen_tcp().""" - server = self.dummy_transport(ModbusSocketFramer) - server.setup_tcp(True, "10.0.0.1", 5101) - assert not await server.transport_listen() - assert not server.transport - - server = self.dummy_transport(ModbusSocketFramer) - server.setup_tcp(True, "localhost", 5101) - assert await server.transport_listen() - assert server.transport - server.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_listen_tls(self): + @pytest.mark.parametrize("positive", [True, False]) + async def test_listen(self, transport_server, params, positive, domain_host): """Test listen_tls().""" - self.setup_CWD() - server = self.dummy_transport(ModbusSocketFramer) - server.setup_tls( + transport_server.setup_tls( True, - "10.0.0.1", - 5101, + domain_host, + BASE_PORT + 6, None, - self.cwd + "crt", - self.cwd + "key", + params.cwd + "crt", + params.cwd + "key", None, "localhost", ) - assert not await server.transport_listen() - assert not server.transport - - server = self.dummy_transport(ModbusSocketFramer) - server.setup_tls( + server = await transport_server.transport_listen() + assert positive == bool(server) + assert positive == bool(transport_server.transport) + transport_server.close() + if server: + server.close() + + @pytest.mark.parametrize("positive", [True]) + async def test_connected(self, transport, transport_server, params, domain_host): + """Test listen/connect tls().""" + transport_server.setup_tls( True, - "127.0.0.1", - 5101, + domain_host, + BASE_PORT + 7, None, - self.cwd + "crt", - self.cwd + "key", + params.cwd + "crt", + params.cwd + "key", None, "localhost", ) - assert await server.transport_listen() - assert server.transport + server = await transport_server.transport_listen() + assert server + + transport.setup_tcp(False, domain_host, BASE_PORT + 7) + assert await transport.transport_connect() != (None, None) + transport.close() + transport_server.close() server.close() - @pytest.mark.xdist_group(name="server_serialize") - async def test_listen_udp(self): + +class TestCommUdpTransport: + """Test for the transport module.""" + + async def test_connect(self): + """Test connect_udp().""" + # always true, since udp is connectionless. + + @pytest.mark.parametrize("positive", [True, False]) + async def test_listen(self, transport_server, positive, domain_host): """Test listen_udp().""" - server = self.dummy_transport(ModbusSocketFramer) - server.setup_udp(True, "10.0.0.1", 5101) - assert not await server.transport_listen() - assert not server.transport - - server = self.dummy_transport(ModbusSocketFramer) - server.setup_udp(True, "localhost", 5101) - assert await server.transport_listen() - assert server.transport + transport_server.setup_udp(True, domain_host, BASE_PORT + 10) + server = await transport_server.transport_listen() + assert positive == bool(server) + assert positive == bool(transport_server.transport) + transport_server.close() + if server: + server.close() + + @pytest.mark.parametrize("positive", [True]) + async def test_connected(self, transport, transport_server, domain_host): + """Test listen/connect udp().""" + transport_server.setup_udp(True, domain_host, BASE_PORT + 11) + server = await transport_server.transport_listen() + assert server + transport.setup_udp(False, domain_host, BASE_PORT + 11) + assert await transport.transport_connect() != (None, None) + transport.close() + transport_server.close() server.close() - @pytest.mark.xdist_group(name="server_serialize") - async def test_listen_serial(self): - """Test listen_serial().""" - server = self.dummy_transport(ModbusSocketFramer) - server.setup_serial( - True, - "no port", + +class TestCommSerialTransport: + """Test for the transport module.""" + + @pytest.mark.parametrize("positive", [True, False]) + async def test_connect(self, transport, positive): + """Test connect_serial().""" + domain_port = ( + f"unix:/localhost:{BASE_PORT + 15}" if positive else "/illegal_port" + ) + transport.setup_serial( + False, + domain_port, 9600, 8, "E", 2, ) - assert not await server.transport_listen() - assert not server.transport - - # there are no positive test, since there are no standard tty port - - @pytest.mark.skipif( - pytest.IS_WINDOWS, reason="Windows do not support unix sockets." - ) - @pytest.mark.xdist_group(name="server_serialize") - async def test_connected_unix(self): - """Test listen/connect unix().""" - server_protocol = self.dummy_transport(ModbusSocketFramer) - domain_socket = gettempdir() + "/test_unix_" + str(time.time()) - server_protocol.setup_unix(True, domain_socket) - server = await server_protocol.transport_listen() - - client = self.dummy_transport(ModbusSocketFramer) - client.setup_unix(False, domain_socket) - assert await client.transport_connect() != (None, None) - server_protocol.comm_params.comm_name = "jan server" - client.comm_params.comm_name = "jan client" - client.close() - server_protocol.close() - server.close() + start = time.time() + assert await transport.transport_connect() == (None, None) + delta = time.time() - start + assert delta < transport.comm_params.timeout_connect * 1.2 + transport.close() - @pytest.mark.xdist_group(name="server_serialize") - async def test_connected_tcp(self): - """Test listen/connect tcp().""" - server_protocol = self.dummy_transport(ModbusSocketFramer) - server_protocol.setup_tcp(True, "localhost", 5101) - assert await server_protocol.transport_listen() - - client = self.dummy_transport(ModbusSocketFramer) - client.setup_tcp(False, "localhost", 5101) - assert await client.transport_connect() != (None, None) - client.close() - server_protocol.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_connected_tls(self): - """Test listen/connect tls().""" - self.setup_CWD() - server_protocol = self.dummy_transport(ModbusSocketFramer) - server_protocol.setup_tls( + async def test_listen(self, transport_server): + """Test listen_serial().""" + transport_server.setup_serial( True, - "127.0.0.1", - 5102, - None, - self.cwd + "crt", - self.cwd + "key", - None, - "localhost", + "/illegal_port", + 9600, + 8, + "E", + 2, ) - assert await server_protocol.transport_listen() + server = await transport_server.transport_listen() + assert not server + assert not transport_server.transport + transport_server.close() - client = self.dummy_transport(ModbusSocketFramer) - client.setup_tls( - False, - "127.0.0.1", - 5102, - None, - self.cwd + "crt", - self.cwd + "key", - None, - "localhost", - ) - assert await client.transport_connect() != (None, None) - client.close() - server_protocol.close() + # there are no positive test, since there are no standard tty port - @pytest.mark.xdist_group(name="server_serialize") - async def test_connected_udp(self): - """Test listen/connect udp().""" - server_protocol = self.dummy_transport(ModbusSocketFramer) - server_protocol.setup_udp(True, "localhost", 5101) - transport = await server_protocol.transport_listen() - assert transport - - client = self.dummy_transport(ModbusSocketFramer) - client.setup_udp(False, "localhost", 5101) - assert await client.transport_connect() != (None, None) - client.close() - server_protocol.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_connected_serial(self): + async def test_connected(self, transport, transport_server): """Test listen/connect serial().""" - server_protocol = self.dummy_transport(ModbusSocketFramer) - server_protocol.setup_tcp(True, "localhost", 5101) - assert await server_protocol.transport_listen() - - client = self.dummy_transport(ModbusSocketFramer) - client.setup_serial( + transport_server.setup_tcp(True, "localhost", BASE_PORT + 16) + server = await transport_server.transport_listen() + assert server + transport.setup_serial( False, - "unix:localhost:5001", + f"socket://localhost:{BASE_PORT + 16}", 9600, 8, "E", 2, ) - assert await client.transport_connect() == (None, None) - client.close() - server_protocol.close() - - @pytest.mark.skipif(pytest.IS_WINDOWS, reason="Windows 3.8 problem.") - @pytest.mark.xdist_group(name="server_serialize") - async def test_connect_reconnect(self): - """Test connect() reconnecting.""" - server = self.dummy_transport(ModbusSocketFramer, comm_name="server mode") - server.setup_tcp(True, "localhost", 5101) - await server.transport_listen() - assert server.transport - - client = self.dummy_transport(ModbusSocketFramer, comm_name="client mode") - client.setup_tcp(False, "localhost", 5101) - assert await client.transport_connect() != (None, None) - server.close() - count = 500 - while client.transport and count: - await asyncio.sleep(0.1) - count -= 1 - assert not client.transport - assert client.reconnect_timer - client.close() + assert await transport.transport_connect() != (None, None) + transport.close() + transport_server.close() server.close() diff --git a/test/transport/test_data.py b/test/transport/test_data.py new file mode 100644 index 000000000..a48333f26 --- /dev/null +++ b/test/transport/test_data.py @@ -0,0 +1,27 @@ +"""Test transport.""" +import asyncio + + +BASE_PORT = 5260 + + +class TestDataTransport: # pylint: disable=too-few-public-methods + """Test for the transport module.""" + + async def test_client_send(self, transport, transport_server): + """Test send().""" + transport_server.setup_tcp(True, "localhost", BASE_PORT + 1) + server = await transport_server.transport_listen() + assert transport_server.transport + + transport.setup_tcp(False, "localhost", BASE_PORT + 1) + assert await transport.transport_connect() != (None, None) + await transport.send(b"ABC") + await asyncio.sleep(2) + assert transport_server.recv_buffer == b"ABC" + await transport_server.send(b"DEF") + await asyncio.sleep(2) + assert transport.recv_buffer == b"DEF" + transport.close() + transport_server.close() + server.close() diff --git a/test/transport/test_reconnect.py b/test/transport/test_reconnect.py index c19c4017e..12943948b 100644 --- a/test/transport/test_reconnect.py +++ b/test/transport/test_reconnect.py @@ -2,128 +2,63 @@ import asyncio from unittest import mock -from pymodbus.framer import ModbusFramer -from pymodbus.transport.transport import BaseTransport +BASE_PORT = 5250 -class TestBaseTransport: - """Test transport module, base part.""" - - base_comm_name = "test comm" - base_reconnect_delay = 1.0 - base_reconnect_delay_max = 7.5 - base_timeout_connect = 2.0 - base_framer = ModbusFramer - base_host = "test host" - base_port = 502 - base_server_hostname = "server test host" - base_baudrate = 9600 - base_bytesize = 8 - base_parity = "e" - base_stopbits = 2 - - class dummy_transport(BaseTransport): - """Transport class for test.""" - - def __init__(self): - """Initialize.""" - super().__init__( - TestBaseTransport.base_comm_name, - [ - TestBaseTransport.base_reconnect_delay * 1000, - TestBaseTransport.base_reconnect_delay_max * 1000, - ], - TestBaseTransport.base_timeout_connect * 1000, - TestBaseTransport.base_framer, - None, - None, - None, - ) - self.abort = mock.MagicMock() - self.close = mock.MagicMock() - @classmethod - async def setup_BaseTransport(cls): - """Create base object.""" - base = BaseTransport( - cls.base_comm_name, - (cls.base_reconnect_delay * 1000, cls.base_reconnect_delay_max * 1000), - cls.base_timeout_connect * 1000, - cls.base_framer, - mock.MagicMock(), - mock.MagicMock(), - mock.MagicMock(), - ) - params = base.CommParamsClass( - done=True, - comm_name=cls.base_comm_name, - reconnect_delay=cls.base_reconnect_delay, - reconnect_delay_max=cls.base_reconnect_delay_max, - timeout_connect=cls.base_timeout_connect, - framer=cls.base_framer, - ) - return base, params +class TestReconnectTransport: + """Test transport module, base part.""" - async def test_no_reconnect_call(self): + async def test_no_reconnect_call(self, transport, commparams): """Test connection_lost().""" - transport, _params = await self.setup_BaseTransport() - transport.setup_tcp(False, self.base_host, self.base_port) + transport.setup_tcp(False, "localhost", BASE_PORT + 1) transport.call_connect_listen = mock.AsyncMock(return_value=(None, None)) transport.connection_made(mock.Mock()) assert not transport.call_connect_listen.call_count - assert transport.reconnect_delay_current == self.base_reconnect_delay - + assert transport.reconnect_delay_current == commparams.reconnect_delay transport.connection_lost(RuntimeError("Connection lost")) assert not transport.call_connect_listen.call_count - assert transport.reconnect_delay_current == self.base_reconnect_delay + assert transport.reconnect_delay_current == commparams.reconnect_delay transport.close() - async def test_reconnect_call(self): + async def test_reconnect_call(self, transport, commparams): """Test connection_lost().""" - transport, _params = await self.setup_BaseTransport() - transport.setup_tcp(False, self.base_host, self.base_port) + transport.setup_tcp(False, "localhost", BASE_PORT + 2) transport.call_connect_listen = mock.AsyncMock(return_value=(None, None)) transport.connection_made(mock.Mock()) transport.connection_lost(RuntimeError("Connection lost")) - await asyncio.sleep(transport.reconnect_delay_current * 1.2) assert transport.call_connect_listen.call_count == 1 - assert transport.reconnect_delay_current == self.base_reconnect_delay * 2 + assert transport.reconnect_delay_current == commparams.reconnect_delay * 2 transport.close() - async def test_multi_reconnect_call(self): + async def test_multi_reconnect_call(self, transport, commparams): """Test connection_lost().""" - transport, _params = await self.setup_BaseTransport() - transport.setup_tcp(False, self.base_host, self.base_port) + transport.setup_tcp(False, "localhost", BASE_PORT + 3) transport.call_connect_listen = mock.AsyncMock(return_value=(None, None)) transport.connection_made(mock.Mock()) transport.connection_lost(RuntimeError("Connection lost")) - await asyncio.sleep(transport.reconnect_delay_current * 1.2) assert transport.call_connect_listen.call_count == 1 - assert transport.reconnect_delay_current == self.base_reconnect_delay * 2 - + assert transport.reconnect_delay_current == commparams.reconnect_delay * 2 await asyncio.sleep(transport.reconnect_delay_current * 1.2) assert transport.call_connect_listen.call_count == 2 - assert transport.reconnect_delay_current == self.base_reconnect_delay * 4 - + assert transport.reconnect_delay_current == commparams.reconnect_delay_max await asyncio.sleep(transport.reconnect_delay_current * 1.2) assert transport.call_connect_listen.call_count == 3 - assert transport.reconnect_delay_current == self.base_reconnect_delay_max + assert transport.reconnect_delay_current == commparams.reconnect_delay_max transport.close() - async def test_reconnect_call_ok(self): + async def test_reconnect_call_ok(self, transport, commparams): """Test connection_lost().""" - transport, _params = await self.setup_BaseTransport() - transport.setup_tcp(False, self.base_host, self.base_port) + transport.setup_tcp(False, "localhost", BASE_PORT + 4) transport.call_connect_listen = mock.AsyncMock( return_value=(mock.Mock(), mock.Mock()) ) transport.connection_made(mock.Mock()) transport.connection_lost(RuntimeError("Connection lost")) - await asyncio.sleep(transport.reconnect_delay_current * 1.2) assert transport.call_connect_listen.call_count == 1 - assert transport.reconnect_delay_current == self.base_reconnect_delay * 2 - assert not transport.reconnect_timer + assert transport.reconnect_delay_current == commparams.reconnect_delay * 2 + assert not transport.reconnect_task transport.close() diff --git a/test/transport/xtest_data.py b/test/transport/xtest_data.py deleted file mode 100644 index 035e3afb9..000000000 --- a/test/transport/xtest_data.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Test transport.""" -import asyncio - -import pytest - -from pymodbus.framer import ModbusFramer, ModbusSocketFramer -from pymodbus.transport.transport import BaseTransport - - -class TestDataTransport: - """Test for the transport module.""" - - class dummy_transport(BaseTransport): - """Transport class for test.""" - - def cb_connection_made(self): - """Handle callback.""" - - def cb_connection_lost(self, _exc): - """Handle callback.""" - - def cb_handle_data(self, _data): - """Handle callback.""" - return 0 - - def __init__(self, framer: ModbusFramer, comm_name="test comm"): - """Initialize.""" - super().__init__( - comm_name, - [2500, 9000], - 2000, - framer, - self.cb_connection_made, - self.cb_connection_lost, - self.cb_handle_data, - ) - - @pytest.mark.skipif(pytest.IS_WINDOWS, reason="Windows problem.") - @pytest.mark.xdist_group(name="server_serialize") - async def test_client_send(self): - """Test connect() reconnecting.""" - server = self.dummy_transport(ModbusSocketFramer, comm_name="server mode") - server.setup_tcp(True, "localhost", 5101) - await server.transport_listen() - assert server.transport - - client = self.dummy_transport(ModbusSocketFramer, comm_name="client mode") - client.setup_tcp(False, "localhost", 5101) - assert await client.transport_connect() != (None, None) - await client.send(b"ABC") - await asyncio.sleep(2) - assert server.recv_buffer == b"ABC" - await server.send(b"DEF") - await asyncio.sleep(2) - assert client.recv_buffer == b"DEF"