Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transport fixes and 100% test coverage. #1580

Merged
merged 1 commit into from
Jun 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions pymodbus/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from pymodbus.logging import Log
from pymodbus.pdu import ModbusRequest, ModbusResponse
from pymodbus.transaction import DictTransactionManager
from pymodbus.transport import BaseTransport
from pymodbus.transport.transport import Transport
from pymodbus.utilities import ModbusTransactionState


class ModbusBaseClient(ModbusClientMixin, BaseTransport):
class ModbusBaseClient(ModbusClientMixin, Transport):
"""**ModbusBaseClient**

**Parameters common to all clients**:
Expand Down Expand Up @@ -94,12 +94,12 @@ def __init__( # pylint: disable=too-many-arguments
**kwargs: Any,
) -> None:
"""Initialize a client instance."""
BaseTransport.__init__(
Transport.__init__(
self,
"comm",
(reconnect_delay * 1000, reconnect_delay_max * 1000),
reconnect_delay * 1000,
reconnect_delay_max * 1000,
timeout * 1000,
framer,
lambda: None,
self.cb_base_connection_lost,
self.cb_base_handle_data,
Expand Down
6 changes: 0 additions & 6 deletions pymodbus/transport/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1 @@
"""Transport."""

__all__ = [
"BaseTransport",
]

from pymodbus.transport.transport import BaseTransport
43 changes: 23 additions & 20 deletions pymodbus/transport/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,22 @@
from dataclasses import dataclass
from typing import Any, Callable, Coroutine

from pymodbus.framer import ModbusFramer
from pymodbus.logging import Log
from pymodbus.transport.serial_asyncio import create_serial_connection


class BaseTransport:
"""Base class for transport types.
class Transport:
"""Transport layer.

BaseTransport contains functions common to all transport types and client/server.
Contains pure transport methods needed to connect/listen, send/receive and close connections
for unix socket, tcp, tls and serial communications.

This class is not available in the pymodbus API, and should not be referenced in Applications.
Contains high level methods like reconnect.

This class is not available in the pymodbus API, and should not be referenced in Applications
nor in the pymodbus documentation.

The class is designed to be an object in the message level class.
"""

@dataclass
Expand All @@ -33,7 +38,6 @@ class CommParamsClass:
reconnect_delay: float = None
reconnect_delay_max: float = None
timeout_connect: float = None
framer: ModbusFramer = None

# tcp / tls / udp / serial
host: str = None
Expand All @@ -60,19 +64,19 @@ def check_done(self):
def __init__(
self,
comm_name: str,
reconnect_delay: tuple[int, int],
reconnect_delay: int,
reconnect_max: int,
timeout_connect: int,
framer: ModbusFramer,
callback_connected: Callable[[], None],
callback_disconnected: Callable[[Exception], None],
callback_data: Callable[[bytes], int],
) -> None:
"""Initialize a transport instance.

:param comm_name: name of this transport connection
:param reconnect_delay: delay and max in milliseconds for first reconnect (0,0 for no reconnect)
:param reconnect_delay: delay in milliseconds for first reconnect (0 for no reconnect)
:param reconnect_delay: max reconnect delay in milliseconds
:param timeout_connect: Max. time in milliseconds for connect to complete
:param framer: Modbus framer to decode/encode messagees.
:param callback_connected: Called when connection is established
:param callback_disconnected: Called when connection is disconnected
:param callback_data: Called when data is received
Expand All @@ -84,19 +88,18 @@ def __init__(
# properties, can be read, but may not be mingled with
self.comm_params = self.CommParamsClass(
comm_name=comm_name,
reconnect_delay=reconnect_delay[0] / 1000,
reconnect_delay_max=reconnect_delay[1] / 1000,
reconnect_delay=reconnect_delay / 1000,
reconnect_delay_max=reconnect_max / 1000,
timeout_connect=timeout_connect / 1000,
framer=framer,
)

self.reconnect_delay_current: float = 0
self.reconnect_delay_current: float = 0.0
self.transport: asyncio.BaseTransport | asyncio.Server = None
self.protocol: asyncio.BaseProtocol = None
self.loop: asyncio.AbstractEventLoop = None
with suppress(RuntimeError):
self.loop = asyncio.get_running_loop()
self.reconnect_timer: asyncio.Task = None
self.reconnect_task: asyncio.Task = None
self.recv_buffer: bytes = b""
self.call_connect_listen: Callable[[], Coroutine[Any, Any, Any]] = lambda: None
self.use_udp = False
Expand Down Expand Up @@ -314,7 +317,7 @@ def connection_lost(self, reason: Exception):
self.cb_connection_lost(reason)
if self.transport:
self.close()
self.reconnect_timer = asyncio.create_task(self.reconnect_connect())
self.reconnect_task = asyncio.create_task(self.reconnect_connect())

def eof_received(self):
"""Call when eof received (other end closed connection).
Expand Down Expand Up @@ -360,9 +363,9 @@ def close(self, reconnect: bool = False) -> None:
self.transport.close()
self.transport = None
self.protocol = None
if not reconnect and self.reconnect_timer:
self.reconnect_timer.cancel()
self.reconnect_timer = None
if not reconnect and self.reconnect_task:
self.reconnect_task.cancel()
self.reconnect_task = None
self.recv_buffer = b""

def reset_delay(self) -> None:
Expand Down Expand Up @@ -395,7 +398,7 @@ async def reconnect_connect(self):
)
except asyncio.CancelledError:
pass
self.reconnect_timer = None
self.reconnect_task = None

# ----------------- #
# The magic methods #
Expand Down
1 change: 1 addition & 0 deletions test/transport/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Test of transport layer."""
88 changes: 88 additions & 0 deletions test/transport/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Test transport."""
import os
from dataclasses import dataclass
from unittest import mock

import pytest
import pytest_asyncio

from pymodbus.transport.transport import Transport


@dataclass
class BaseParams(Transport.CommParamsClass):
"""Base parameters for all transport testing."""

comm_name = "test comm"
reconnect_delay = 1000
reconnect_delay_max = 3500
timeout_connect = 2000
host = "test host"
port = 502
server_hostname = "server test host"
baudrate = 9600
bytesize = 8
parity = "e"
stopbits = 2
cwd = os.path.dirname(__file__) + "/../../examples/certificates/pymodbus."


@pytest.fixture(name="params")
def prepare_baseparams():
"""Prepare BaseParams class."""
return BaseParams


class DummySocket: # pylint: disable=too-few-public-methods
"""Socket simulator for test."""

def __init__(self):
"""Initialize."""
self.close = mock.Mock()
self.abort = mock.Mock()


@pytest.fixture(name="dummy_socket")
def prepare_dummysocket():
"""Prepare dummy_socket class."""
return DummySocket


@pytest.fixture(name="commparams")
def prepare_testparams():
"""Prepare CommParamsClass object."""
return Transport.CommParamsClass(
done=True,
comm_name=BaseParams.comm_name,
reconnect_delay=BaseParams.reconnect_delay / 1000,
reconnect_delay_max=BaseParams.reconnect_delay_max / 1000,
timeout_connect=BaseParams.timeout_connect / 1000,
)


@pytest.fixture(name="transport")
async def prepare_transport():
"""Prepare transport object."""
return Transport(
BaseParams.comm_name,
BaseParams.reconnect_delay,
BaseParams.reconnect_delay_max,
BaseParams.timeout_connect,
mock.Mock(name="cb_connection_made"),
mock.Mock(name="cb_connection_lost"),
mock.Mock(name="cb_handle_data", return_value=0),
)


@pytest_asyncio.fixture(name="transport_server")
async def prepare_transport_server():
"""Prepare transport object."""
return Transport(
BaseParams.comm_name,
BaseParams.reconnect_delay,
BaseParams.reconnect_delay_max,
BaseParams.timeout_connect,
mock.Mock(name="cb_connection_made"),
mock.Mock(name="cb_connection_lost"),
mock.Mock(name="cb_handle_data", return_value=0),
)
Loading