Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace lambda with functools.partial in transport. #2047

Merged
merged 5 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions pymodbus/transport/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import ssl
from contextlib import suppress
from enum import Enum
from functools import partial
from typing import Any, Callable, Coroutine

from pymodbus.logging import Log
Expand Down Expand Up @@ -175,7 +176,7 @@ def __init__(
if self.comm_params.comm_type == CommType.SERIAL and NULLMODEM_HOST in host:
host, port = NULLMODEM_HOST, int(host[9:].split(":")[1])
if host == NULLMODEM_HOST:
self.call_create = lambda: self.create_nullmodem(port)
self.call_create = partial(self.create_nullmodem, port)
return
if (
self.comm_params.comm_type == CommType.SERIAL
Expand All @@ -191,7 +192,7 @@ def __init__(
def init_setup_connect_listen(self, host: str, port: int) -> None:
"""Handle connect/listen handler."""
if self.comm_params.comm_type == CommType.SERIAL:
self.call_create = lambda: create_serial_connection( # pragma: no cover
self.call_create = partial(create_serial_connection,
self.loop,
self.handle_new_connection,
host,
Expand All @@ -205,19 +206,19 @@ def init_setup_connect_listen(self, host: str, port: int) -> None:
return
if self.comm_params.comm_type == CommType.UDP:
if self.is_server:
self.call_create = lambda: self.loop.create_datagram_endpoint( # pragma: no cover
self.call_create = partial(self.loop.create_datagram_endpoint,
self.handle_new_connection,
local_addr=(host, port),
)
else:
self.call_create = lambda: self.loop.create_datagram_endpoint( # pragma: no cover
self.call_create = partial(self.loop.create_datagram_endpoint,
self.handle_new_connection,
remote_addr=(host, port),
)
return
# TLS and TCP
if self.is_server:
self.call_create = lambda: self.loop.create_server( # pragma: no cover
self.call_create = partial(self.loop.create_server,
self.handle_new_connection,
host,
port,
Expand All @@ -226,7 +227,7 @@ def init_setup_connect_listen(self, host: str, port: int) -> None:
start_serving=True,
)
else:
self.call_create = lambda: self.loop.create_connection( # pragma: no cover
self.call_create = partial(self.loop.create_connection,
self.handle_new_connection,
host,
port,
Expand Down
20 changes: 10 additions & 10 deletions test/transport/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ async def prepare_dummy_protocol():
@pytest.fixture(name="client")
async def prepare_protocol(use_clc):
"""Prepare transport object."""
transport = ModbusProtocol(use_clc, False)
transport.callback_connected = mock.Mock()
transport.callback_disconnected = mock.Mock()
transport.callback_data = mock.Mock(return_value=0)
if use_clc.comm_type == CommType.TLS:
cwd = os.path.dirname(__file__) + "/../../examples/certificates/pymodbus."
transport.comm_params.sslctx = use_clc.generate_ssl(
use_clc.sslctx = use_clc.generate_ssl(
False, certfile=cwd + "crt", keyfile=cwd + "key"
)
transport = ModbusProtocol(use_clc, False)
transport.callback_connected = mock.Mock()
transport.callback_disconnected = mock.Mock()
transport.callback_data = mock.Mock(return_value=0)
if use_clc.comm_type == CommType.SERIAL:
transport.comm_params.host = f"socket://localhost:{transport.comm_params.port}"
return transport
Expand All @@ -68,13 +68,13 @@ async def prepare_protocol(use_clc):
@pytest.fixture(name="server")
async def prepare_transport_server(use_cls):
"""Prepare transport object."""
transport = ModbusProtocol(use_cls, True)
transport.callback_connected = mock.Mock()
transport.callback_disconnected = mock.Mock()
transport.callback_data = mock.Mock(return_value=0)
if use_cls.comm_type == CommType.TLS:
cwd = os.path.dirname(__file__) + "/../../examples/certificates/pymodbus."
transport.comm_params.sslctx = use_cls.generate_ssl(
use_cls.sslctx = use_cls.generate_ssl(
True, certfile=cwd + "crt", keyfile=cwd + "key"
)
transport = ModbusProtocol(use_cls, True)
transport.callback_connected = mock.Mock()
transport.callback_disconnected = mock.Mock()
transport.callback_data = mock.Mock(return_value=0)
return transport
35 changes: 9 additions & 26 deletions test/transport/test_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,9 @@ async def test_external_methods(self, inx):
]
methods[inx]()

@pytest.mark.skipif(os.name == "nt", reason="Windows not supported")
async def test_create_serial(self):
"""Test external methods."""
if os.name == "nt":
return

transport, protocol = await create_serial_connection(
asyncio.get_running_loop(), mock.Mock, url="dummy"
)
Expand Down Expand Up @@ -110,31 +108,25 @@ async def test_close(self):
comm.sync_serial = None
comm.close()

@pytest.mark.skipif(os.name == "nt", reason="Windows not supported")
async def test_polling(self):
"""Test polling."""
if os.name == "nt":
return

comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy")
comm.sync_serial = mock.MagicMock()
comm.sync_serial.read.side_effect = asyncio.CancelledError("test")
await comm.polling_task()

@pytest.mark.skipif(os.name == "nt", reason="Windows not supported")
async def test_poll_task(self):
"""Test polling."""
if os.name == "nt":
return

comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy")
comm.sync_serial = mock.MagicMock()
comm.sync_serial.read.side_effect = serial.SerialException("test")
await comm.polling_task()

@pytest.mark.skipif(os.name == "nt", reason="Windows not supported")
async def test_poll_task2(self):
"""Test polling."""
if os.name == "nt":
return

comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy")
comm.sync_serial = mock.MagicMock()
comm.sync_serial = mock.MagicMock()
Expand All @@ -144,57 +136,48 @@ async def test_poll_task2(self):
await comm.polling_task()


@pytest.mark.skipif(os.name == "nt", reason="Windows not supported")
async def test_write_exception(self):
"""Test write exception."""
if os.name == "nt":
return

comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy")
comm.sync_serial = mock.MagicMock()
comm.sync_serial.write.side_effect = BlockingIOError("test")
comm.intern_write_ready()
comm.sync_serial.write.side_effect = serial.SerialException("test")
comm.intern_write_ready()

@pytest.mark.skipif(os.name == "nt", reason="Windows not supported")
async def test_write_ok(self):
"""Test write exception."""
if os.name == "nt":
return

comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy")
comm.sync_serial = mock.MagicMock()
comm.sync_serial.write.return_value = 4
comm.intern_write_buffer.append(b"abcd")
comm.intern_write_ready()

@pytest.mark.skipif(os.name == "nt", reason="Windows not supported")
async def test_write_len(self):
"""Test write exception."""
if os.name == "nt":
return

comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy")
comm.sync_serial = mock.MagicMock()
comm.sync_serial.write.return_value = 3
comm.async_loop.add_writer = mock.Mock()
comm.intern_write_buffer.append(b"abcd")
comm.intern_write_ready()

@pytest.mark.skipif(os.name == "nt", reason="Windows not supported")
async def test_write_force(self):
"""Test write exception."""
if os.name == "nt":
return
comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy")
comm.poll_task = True
comm.sync_serial = mock.MagicMock()
comm.sync_serial.write.return_value = 3
comm.intern_write_buffer.append(b"abcd")
comm.intern_write_ready()

@pytest.mark.skipif(os.name == "nt", reason="Windows not supported")
async def test_read_ready(self):
"""Test polling."""
if os.name == "nt":
return

comm = SerialTransport(asyncio.get_running_loop(), mock.Mock(), "dummy")
comm.sync_serial = mock.MagicMock()
comm.intern_protocol = mock.Mock()
Expand Down