diff --git a/pymodbus/transport/transport.py b/pymodbus/transport/transport.py index 51b456433..7a808f8c2 100644 --- a/pymodbus/transport/transport.py +++ b/pymodbus/transport/transport.py @@ -53,6 +53,7 @@ import ssl from contextlib import suppress from enum import Enum +from functools import partial from typing import Any, Callable, Coroutine from pymodbus.logging import Log @@ -175,7 +176,7 @@ def __init__( if self.comm_params.comm_type == CommType.SERIAL and NULLMODEM_HOST in host: host, port = NULLMODEM_HOST, int(host[9:].split(":")[1]) if host == NULLMODEM_HOST: - self.call_create = lambda: self.create_nullmodem(port) + self.call_create = partial(self.create_nullmodem, port) return if ( self.comm_params.comm_type == CommType.SERIAL @@ -191,7 +192,7 @@ def __init__( def init_setup_connect_listen(self, host: str, port: int) -> None: """Handle connect/listen handler.""" if self.comm_params.comm_type == CommType.SERIAL: - self.call_create = lambda: create_serial_connection( # pragma: no cover + self.call_create = partial(create_serial_connection, self.loop, self.handle_new_connection, host, @@ -205,19 +206,19 @@ def init_setup_connect_listen(self, host: str, port: int) -> None: return if self.comm_params.comm_type == CommType.UDP: if self.is_server: - self.call_create = lambda: self.loop.create_datagram_endpoint( # pragma: no cover + self.call_create = partial(self.loop.create_datagram_endpoint, self.handle_new_connection, local_addr=(host, port), ) else: - self.call_create = lambda: self.loop.create_datagram_endpoint( # pragma: no cover + self.call_create = partial(self.loop.create_datagram_endpoint, self.handle_new_connection, remote_addr=(host, port), ) return # TLS and TCP if self.is_server: - self.call_create = lambda: self.loop.create_server( # pragma: no cover + self.call_create = partial(self.loop.create_server, self.handle_new_connection, host, port, @@ -226,7 +227,7 @@ def init_setup_connect_listen(self, host: str, port: int) -> None: start_serving=True, ) else: - self.call_create = lambda: self.loop.create_connection( # pragma: no cover + self.call_create = partial(self.loop.create_connection, self.handle_new_connection, host, port, diff --git a/test/transport/conftest.py b/test/transport/conftest.py index 3b0d1bf49..28a4fa98f 100644 --- a/test/transport/conftest.py +++ b/test/transport/conftest.py @@ -51,15 +51,15 @@ async def prepare_dummy_protocol(): @pytest.fixture(name="client") async def prepare_protocol(use_clc): """Prepare transport object.""" - transport = ModbusProtocol(use_clc, False) - transport.callback_connected = mock.Mock() - transport.callback_disconnected = mock.Mock() - transport.callback_data = mock.Mock(return_value=0) if use_clc.comm_type == CommType.TLS: cwd = os.path.dirname(__file__) + "/../../examples/certificates/pymodbus." - transport.comm_params.sslctx = use_clc.generate_ssl( + use_clc.sslctx = use_clc.generate_ssl( False, certfile=cwd + "crt", keyfile=cwd + "key" ) + transport = ModbusProtocol(use_clc, False) + transport.callback_connected = mock.Mock() + transport.callback_disconnected = mock.Mock() + transport.callback_data = mock.Mock(return_value=0) if use_clc.comm_type == CommType.SERIAL: transport.comm_params.host = f"socket://localhost:{transport.comm_params.port}" return transport @@ -68,13 +68,13 @@ async def prepare_protocol(use_clc): @pytest.fixture(name="server") async def prepare_transport_server(use_cls): """Prepare transport object.""" - transport = ModbusProtocol(use_cls, True) - transport.callback_connected = mock.Mock() - transport.callback_disconnected = mock.Mock() - transport.callback_data = mock.Mock(return_value=0) if use_cls.comm_type == CommType.TLS: cwd = os.path.dirname(__file__) + "/../../examples/certificates/pymodbus." - transport.comm_params.sslctx = use_cls.generate_ssl( + use_cls.sslctx = use_cls.generate_ssl( True, certfile=cwd + "crt", keyfile=cwd + "key" ) + transport = ModbusProtocol(use_cls, True) + transport.callback_connected = mock.Mock() + transport.callback_disconnected = mock.Mock() + transport.callback_data = mock.Mock(return_value=0) return transport diff --git a/test/transport/test_serial.py b/test/transport/test_serial.py index 700645627..77a6c154c 100644 --- a/test/transport/test_serial.py +++ b/test/transport/test_serial.py @@ -68,11 +68,9 @@ async def test_external_methods(self, inx): ] methods[inx]() + @pytest.mark.skipif(os.name == "nt", reason="Windows not supported") async def test_create_serial(self): """Test external methods.""" - if os.name == "nt": - return - transport, protocol = await create_serial_connection( asyncio.get_running_loop(), mock.Mock, url="dummy" ) @@ -110,31 +108,25 @@ async def test_close(self): comm.sync_serial = None comm.close() + @pytest.mark.skipif(os.name == "nt", reason="Windows not supported") async def test_polling(self): """Test polling.""" - if os.name == "nt": - return - comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy") comm.sync_serial = mock.MagicMock() comm.sync_serial.read.side_effect = asyncio.CancelledError("test") await comm.polling_task() + @pytest.mark.skipif(os.name == "nt", reason="Windows not supported") async def test_poll_task(self): """Test polling.""" - if os.name == "nt": - return - comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy") comm.sync_serial = mock.MagicMock() comm.sync_serial.read.side_effect = serial.SerialException("test") await comm.polling_task() + @pytest.mark.skipif(os.name == "nt", reason="Windows not supported") async def test_poll_task2(self): """Test polling.""" - if os.name == "nt": - return - comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy") comm.sync_serial = mock.MagicMock() comm.sync_serial = mock.MagicMock() @@ -144,11 +136,9 @@ async def test_poll_task2(self): await comm.polling_task() + @pytest.mark.skipif(os.name == "nt", reason="Windows not supported") async def test_write_exception(self): """Test write exception.""" - if os.name == "nt": - return - comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy") comm.sync_serial = mock.MagicMock() comm.sync_serial.write.side_effect = BlockingIOError("test") @@ -156,22 +146,18 @@ async def test_write_exception(self): comm.sync_serial.write.side_effect = serial.SerialException("test") comm.intern_write_ready() + @pytest.mark.skipif(os.name == "nt", reason="Windows not supported") async def test_write_ok(self): """Test write exception.""" - if os.name == "nt": - return - comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy") comm.sync_serial = mock.MagicMock() comm.sync_serial.write.return_value = 4 comm.intern_write_buffer.append(b"abcd") comm.intern_write_ready() + @pytest.mark.skipif(os.name == "nt", reason="Windows not supported") async def test_write_len(self): """Test write exception.""" - if os.name == "nt": - return - comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy") comm.sync_serial = mock.MagicMock() comm.sync_serial.write.return_value = 3 @@ -179,10 +165,9 @@ async def test_write_len(self): comm.intern_write_buffer.append(b"abcd") comm.intern_write_ready() + @pytest.mark.skipif(os.name == "nt", reason="Windows not supported") async def test_write_force(self): """Test write exception.""" - if os.name == "nt": - return comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy") comm.poll_task = True comm.sync_serial = mock.MagicMock() @@ -190,11 +175,9 @@ async def test_write_force(self): comm.intern_write_buffer.append(b"abcd") comm.intern_write_ready() + @pytest.mark.skipif(os.name == "nt", reason="Windows not supported") async def test_read_ready(self): """Test polling.""" - if os.name == "nt": - return - comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy") comm.sync_serial = mock.MagicMock() comm.intern_protocol = mock.Mock()