Skip to content

Commit

Permalink
Replace lambda with functools.partial in transport. (#2047)
Browse files Browse the repository at this point in the history
  • Loading branch information
janiversen authored Feb 22, 2024
1 parent 5c4fa10 commit bb9ce24
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 42 deletions.
13 changes: 7 additions & 6 deletions pymodbus/transport/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import ssl
from contextlib import suppress
from enum import Enum
from functools import partial
from typing import Any, Callable, Coroutine

from pymodbus.logging import Log
Expand Down Expand Up @@ -175,7 +176,7 @@ def __init__(
if self.comm_params.comm_type == CommType.SERIAL and NULLMODEM_HOST in host:
host, port = NULLMODEM_HOST, int(host[9:].split(":")[1])
if host == NULLMODEM_HOST:
self.call_create = lambda: self.create_nullmodem(port)
self.call_create = partial(self.create_nullmodem, port)
return
if (
self.comm_params.comm_type == CommType.SERIAL
Expand All @@ -191,7 +192,7 @@ def __init__(
def init_setup_connect_listen(self, host: str, port: int) -> None:
"""Handle connect/listen handler."""
if self.comm_params.comm_type == CommType.SERIAL:
self.call_create = lambda: create_serial_connection( # pragma: no cover
self.call_create = partial(create_serial_connection,
self.loop,
self.handle_new_connection,
host,
Expand All @@ -205,19 +206,19 @@ def init_setup_connect_listen(self, host: str, port: int) -> None:
return
if self.comm_params.comm_type == CommType.UDP:
if self.is_server:
self.call_create = lambda: self.loop.create_datagram_endpoint( # pragma: no cover
self.call_create = partial(self.loop.create_datagram_endpoint,
self.handle_new_connection,
local_addr=(host, port),
)
else:
self.call_create = lambda: self.loop.create_datagram_endpoint( # pragma: no cover
self.call_create = partial(self.loop.create_datagram_endpoint,
self.handle_new_connection,
remote_addr=(host, port),
)
return
# TLS and TCP
if self.is_server:
self.call_create = lambda: self.loop.create_server( # pragma: no cover
self.call_create = partial(self.loop.create_server,
self.handle_new_connection,
host,
port,
Expand All @@ -226,7 +227,7 @@ def init_setup_connect_listen(self, host: str, port: int) -> None:
start_serving=True,
)
else:
self.call_create = lambda: self.loop.create_connection( # pragma: no cover
self.call_create = partial(self.loop.create_connection,
self.handle_new_connection,
host,
port,
Expand Down
20 changes: 10 additions & 10 deletions test/transport/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ async def prepare_dummy_protocol():
@pytest.fixture(name="client")
async def prepare_protocol(use_clc):
"""Prepare transport object."""
transport = ModbusProtocol(use_clc, False)
transport.callback_connected = mock.Mock()
transport.callback_disconnected = mock.Mock()
transport.callback_data = mock.Mock(return_value=0)
if use_clc.comm_type == CommType.TLS:
cwd = os.path.dirname(__file__) + "/../../examples/certificates/pymodbus."
transport.comm_params.sslctx = use_clc.generate_ssl(
use_clc.sslctx = use_clc.generate_ssl(
False, certfile=cwd + "crt", keyfile=cwd + "key"
)
transport = ModbusProtocol(use_clc, False)
transport.callback_connected = mock.Mock()
transport.callback_disconnected = mock.Mock()
transport.callback_data = mock.Mock(return_value=0)
if use_clc.comm_type == CommType.SERIAL:
transport.comm_params.host = f"socket://localhost:{transport.comm_params.port}"
return transport
Expand All @@ -68,13 +68,13 @@ async def prepare_protocol(use_clc):
@pytest.fixture(name="server")
async def prepare_transport_server(use_cls):
"""Prepare transport object."""
transport = ModbusProtocol(use_cls, True)
transport.callback_connected = mock.Mock()
transport.callback_disconnected = mock.Mock()
transport.callback_data = mock.Mock(return_value=0)
if use_cls.comm_type == CommType.TLS:
cwd = os.path.dirname(__file__) + "/../../examples/certificates/pymodbus."
transport.comm_params.sslctx = use_cls.generate_ssl(
use_cls.sslctx = use_cls.generate_ssl(
True, certfile=cwd + "crt", keyfile=cwd + "key"
)
transport = ModbusProtocol(use_cls, True)
transport.callback_connected = mock.Mock()
transport.callback_disconnected = mock.Mock()
transport.callback_data = mock.Mock(return_value=0)
return transport
35 changes: 9 additions & 26 deletions test/transport/test_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,9 @@ async def test_external_methods(self, inx):
]
methods[inx]()

@pytest.mark.skipif(os.name == "nt", reason="Windows not supported")
async def test_create_serial(self):
"""Test external methods."""
if os.name == "nt":
return

transport, protocol = await create_serial_connection(
asyncio.get_running_loop(), mock.Mock, url="dummy"
)
Expand Down Expand Up @@ -110,31 +108,25 @@ async def test_close(self):
comm.sync_serial = None
comm.close()

@pytest.mark.skipif(os.name == "nt", reason="Windows not supported")
async def test_polling(self):
"""Test polling."""
if os.name == "nt":
return

comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy")
comm.sync_serial = mock.MagicMock()
comm.sync_serial.read.side_effect = asyncio.CancelledError("test")
await comm.polling_task()

@pytest.mark.skipif(os.name == "nt", reason="Windows not supported")
async def test_poll_task(self):
"""Test polling."""
if os.name == "nt":
return

comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy")
comm.sync_serial = mock.MagicMock()
comm.sync_serial.read.side_effect = serial.SerialException("test")
await comm.polling_task()

@pytest.mark.skipif(os.name == "nt", reason="Windows not supported")
async def test_poll_task2(self):
"""Test polling."""
if os.name == "nt":
return

comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy")
comm.sync_serial = mock.MagicMock()
comm.sync_serial = mock.MagicMock()
Expand All @@ -144,57 +136,48 @@ async def test_poll_task2(self):
await comm.polling_task()


@pytest.mark.skipif(os.name == "nt", reason="Windows not supported")
async def test_write_exception(self):
"""Test write exception."""
if os.name == "nt":
return

comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy")
comm.sync_serial = mock.MagicMock()
comm.sync_serial.write.side_effect = BlockingIOError("test")
comm.intern_write_ready()
comm.sync_serial.write.side_effect = serial.SerialException("test")
comm.intern_write_ready()

@pytest.mark.skipif(os.name == "nt", reason="Windows not supported")
async def test_write_ok(self):
"""Test write exception."""
if os.name == "nt":
return

comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy")
comm.sync_serial = mock.MagicMock()
comm.sync_serial.write.return_value = 4
comm.intern_write_buffer.append(b"abcd")
comm.intern_write_ready()

@pytest.mark.skipif(os.name == "nt", reason="Windows not supported")
async def test_write_len(self):
"""Test write exception."""
if os.name == "nt":
return

comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy")
comm.sync_serial = mock.MagicMock()
comm.sync_serial.write.return_value = 3
comm.async_loop.add_writer = mock.Mock()
comm.intern_write_buffer.append(b"abcd")
comm.intern_write_ready()

@pytest.mark.skipif(os.name == "nt", reason="Windows not supported")
async def test_write_force(self):
"""Test write exception."""
if os.name == "nt":
return
comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy")
comm.poll_task = True
comm.sync_serial = mock.MagicMock()
comm.sync_serial.write.return_value = 3
comm.intern_write_buffer.append(b"abcd")
comm.intern_write_ready()

@pytest.mark.skipif(os.name == "nt", reason="Windows not supported")
async def test_read_ready(self):
"""Test polling."""
if os.name == "nt":
return

comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy")
comm.sync_serial = mock.MagicMock()
comm.intern_protocol = mock.Mock()
Expand Down

0 comments on commit bb9ce24

Please sign in to comment.