Skip to content

Commit

Permalink
New common transport layer. (#1492)
Browse files Browse the repository at this point in the history
  • Loading branch information
janiversen authored Apr 20, 2023
1 parent 67d875b commit 7a4c609
Show file tree
Hide file tree
Showing 11 changed files with 248 additions and 344 deletions.
2 changes: 1 addition & 1 deletion examples/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ async def run_async_client(client, modbus_calls=None):
assert client.connected
if modbus_calls:
await modbus_calls(client)
await client.close()
client.close()
_logger.info("### End of Program")


Expand Down
88 changes: 39 additions & 49 deletions pymodbus/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
from pymodbus.logging import Log
from pymodbus.pdu import ModbusRequest, ModbusResponse
from pymodbus.transaction import DictTransactionManager
from pymodbus.transport import BaseTransport
from pymodbus.utilities import ModbusTransactionState


class ModbusBaseClient(ModbusClientMixin):
class ModbusBaseClient(ModbusClientMixin, BaseTransport):
"""**ModbusBaseClient**
**Parameters common to all clients**:
Expand Down Expand Up @@ -63,8 +64,6 @@ class _params: # pylint: disable=too-many-instance-attributes
broadcast_enable: bool = None
kwargs: dict = None
reconnect_delay: int = None
reconnect_delay_max: int = None
on_reconnect_callback: Callable[[], None] | None = None

baudrate: int = None
bytesize: int = None
Expand Down Expand Up @@ -95,6 +94,7 @@ def __init__( # pylint: disable=too-many-arguments
**kwargs: Any,
) -> None:
"""Initialize a client instance."""
BaseTransport.__init__(self)
self.params = self._params()
self.params.framer = framer
self.params.timeout = float(timeout)
Expand All @@ -104,8 +104,8 @@ def __init__( # pylint: disable=too-many-arguments
self.params.strict = bool(strict)
self.params.broadcast_enable = bool(broadcast_enable)
self.params.reconnect_delay = int(reconnect_delay)
self.params.reconnect_delay_max = int(reconnect_delay_max)
self.params.on_reconnect_callback = on_reconnect_callback
self.reconnect_delay_max = int(reconnect_delay_max)
self.on_reconnect_callback = on_reconnect_callback
self.params.kwargs = kwargs

# Common variables.
Expand All @@ -115,15 +115,14 @@ def __init__( # pylint: disable=too-many-arguments
)
self.delay_ms = self.params.reconnect_delay
self.use_protocol = False
self._connected = False
self.use_udp = False
self.state = ModbusTransactionState.IDLE
self.last_frame_end: float = 0
self.silent_interval: float = 0
self.transport = None
self._reconnect_task = None

# Initialize mixin
super().__init__()
ModbusClientMixin.__init__(self)

# ----------------------------------------------------------------------- #
# Client external interface
Expand Down Expand Up @@ -174,30 +173,16 @@ def execute(self, request: ModbusRequest = None) -> ModbusResponse:
:raises ConnectionException: Check exception text.
"""
if self.use_protocol:
if not self._connected:
if not self.transport:
raise ConnectionException(f"Not connected[{str(self)}]")
return self.async_execute(request)
if not self.connect():
raise ConnectionException(f"Failed to connect[{str(self)}]")
return self.transaction.execute(request)

def close(self) -> None:
"""Close the underlying socket connection (call **sync/async**)."""
raise NotImplementedException

# ----------------------------------------------------------------------- #
# Merged client methods
# ----------------------------------------------------------------------- #
def client_made_connection(self, protocol):
"""Run transport specific connection."""

def client_lost_connection(self, protocol):
"""Run transport specific connection lost."""

def datagram_received(self, data, _addr):
"""Receive datagram."""
self.data_received(data)

async def async_execute(self, request=None):
"""Execute requests asynchronously."""
request.transaction_id = self.transaction.getNextTID()
Expand All @@ -218,29 +203,13 @@ async def async_execute(self, request=None):
raise
return resp

def connection_made(self, transport):
"""Call when a connection is made.
The transport argument is the transport representing the connection.
"""
self.transport = transport
Log.debug("Client connected to modbus server")
self._connected = True
self.client_made_connection(self)

def connection_lost(self, reason):
"""Call when the connection is lost or closed.
The argument is either an exception object or None
"""
if self.transport:
self.transport.abort()
if hasattr(self.transport, "_sock"):
self.transport._sock.close() # pylint: disable=protected-access
self.transport = None
self.client_lost_connection(self)
Log.debug("Client disconnected from modbus server: {}", reason)
self._connected = False
self.close(reconnect=True)
for tid in list(self.transaction):
self.raise_future(
self.transaction.getTransaction(tid),
Expand Down Expand Up @@ -277,22 +246,43 @@ def _handle_response(self, reply, **_kwargs):
def _build_response(self, tid):
"""Return a deferred response for the current request."""
my_future = self.create_future()
if not self._connected:
if not self.transport:
self.raise_future(my_future, ConnectionException("Client is not connected"))
else:
self.transaction.addTransaction(my_future, tid)
return my_future

@property
def async_connected(self):
"""Return connection status."""
return self._connected
def close(self, reconnect: bool = False) -> None:
"""Close connection.
async def async_close(self):
"""Close connection."""
:param reconnect: (default false), try to reconnect
"""
if self.transport:
if hasattr(self.transport, "_sock"):
self.transport._sock.close() # pylint: disable=protected-access
self.transport.abort()
self.transport.close()
self._connected = False
self.transport = None
if self._reconnect_task:
self._reconnect_task.cancel()
self._reconnect_task = None

if not reconnect or not self.delay_ms:
self.delay_ms = 0
return

self._reconnect_task = asyncio.create_task(self._reconnect())

async def _reconnect(self):
"""Reconnect."""
Log.debug("Waiting {} ms before next connection attempt.", self.delay_ms)
await asyncio.sleep(self.delay_ms / 1000)
self.delay_ms = min(2 * self.delay_ms, self.reconnect_delay_max)

self._reconnect_task = None
if self.on_reconnect_callback:
self.on_reconnect_callback()
return await self.connect()

# ----------------------------------------------------------------------- #
# Internal methods
Expand Down Expand Up @@ -353,7 +343,7 @@ def __exit__(self, klass, value, traceback):

async def __aexit__(self, klass, value, traceback):
"""Implement the client with exit block."""
await self.close()
self.close()

def __str__(self):
"""Build a string representation of the connection.
Expand Down
92 changes: 16 additions & 76 deletions pymodbus/client/serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def run():
await client.connect()
...
await client.close()
client.close()
"""

transport = None
Expand All @@ -68,25 +68,6 @@ def __init__(
self.params.parity = parity
self.params.stopbits = stopbits
self.params.handle_local_echo = handle_local_echo
self.loop = None
self._connected_event = asyncio.Event()
self._reconnect_task = None

async def close(self): # pylint: disable=invalid-overridden-method
"""Stop connection."""

# prevent reconnect:
self.delay_ms = 0
if self.connected:
if self.transport:
self.transport.close()
await self.async_close()
await asyncio.sleep(0.1)

# if there is an unfinished delayed reconnection attempt pending, cancel it
if self._reconnect_task:
self._reconnect_task.cancel()
self._reconnect_task = None

def _create_protocol(self):
"""Create a protocol instance."""
Expand All @@ -95,74 +76,33 @@ def _create_protocol(self):
@property
def connected(self):
"""Connect internal."""
return self._connected_event.is_set()
return self.transport is not None

async def connect(self): # pylint: disable=invalid-overridden-method
"""Connect Async client."""
# get current loop, if there are no loop a RuntimeError will be raised
self.loop = asyncio.get_running_loop()

Log.debug("Starting serial connection")
try:
await create_serial_connection(
self.loop,
self._create_protocol,
self.params.port,
baudrate=self.params.baudrate,
bytesize=self.params.bytesize,
stopbits=self.params.stopbits,
parity=self.params.parity,
await asyncio.wait_for(
create_serial_connection(
self.loop,
self._create_protocol,
self.params.port,
baudrate=self.params.baudrate,
bytesize=self.params.bytesize,
stopbits=self.params.stopbits,
parity=self.params.parity,
timeout=self.params.timeout,
**self.params.kwargs,
),
timeout=self.params.timeout,
**self.params.kwargs,
)
await self._connected_event.wait()
Log.info("Connected to {}", self.params.port)
except Exception as exc: # pylint: disable=broad-except
Log.warning("Failed to connect: {}", exc)
if self.delay_ms > 0:
self._launch_reconnect()
self.close(reconnect=True)
return self.connected

def client_made_connection(self, protocol):
"""Notify successful connection."""
Log.info("Serial connected.")
if not self.connected:
self._connected_event.set()
else:
Log.error("Factory protocol connect callback called while connected.")

def client_lost_connection(self, protocol):
"""Notify lost connection."""
Log.info("Serial lost connection.")
if protocol is not self:
Log.error("Serial: protocol is not self.")

self._connected_event.clear()
if self.delay_ms:
self._launch_reconnect()

def _launch_reconnect(self):
"""Launch delayed reconnection coroutine"""
if self._reconnect_task:
Log.warning(
"Ignoring launch of delayed reconnection, another is in progress"
)
else:
# store the future in a member variable so we know we have a pending reconnection attempt
# also prevents its garbage collection
self._reconnect_task = asyncio.create_task(self._reconnect())

async def _reconnect(self):
"""Reconnect."""
Log.debug("Waiting {} ms before next connection attempt.", self.delay_ms)
await asyncio.sleep(self.delay_ms / 1000)
self.delay_ms = min(2 * self.delay_ms, self.params.reconnect_delay_max)

self._reconnect_task = None
if self.params.on_reconnect_callback:
self.params.on_reconnect_callback()
return await self.connect()


class ModbusSerialClient(ModbusBaseClient):
"""**ModbusSerialClient**.
Expand Down Expand Up @@ -267,7 +207,7 @@ def connect(self):
self.close()
return self.socket is not None

def close(self):
def close(self): # pylint: disable=arguments-differ
"""Close the underlying socket connection."""
if self.socket:
self.socket.close()
Expand Down
Loading

0 comments on commit 7a4c609

Please sign in to comment.