Skip to content

Commit

Permalink
Add sync transactionManager.
Browse files Browse the repository at this point in the history
  • Loading branch information
janiversen committed Nov 17, 2024
1 parent 3378a73 commit 983230f
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 88 deletions.
2 changes: 1 addition & 1 deletion pymodbus/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pymodbus.framer import FRAMER_NAME_TO_CLASS, FramerBase, FramerType
from pymodbus.logging import Log
from pymodbus.pdu import DecodePDU, ModbusPDU
from pymodbus.transaction import SyncModbusTransactionManager
from pymodbus.transaction.old_transaction import SyncModbusTransactionManager
from pymodbus.transport import CommParams
from pymodbus.utilities import ModbusTransactionState

Expand Down
6 changes: 0 additions & 6 deletions pymodbus/transaction/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
"""Transaction."""
__all__ = [
"ModbusTransactionManager",
"SyncModbusTransactionManager",
"TransactionManager",
]

from pymodbus.transaction.old_transaction import (
ModbusTransactionManager,
SyncModbusTransactionManager,
)
from pymodbus.transaction.transaction import TransactionManager
81 changes: 68 additions & 13 deletions pymodbus/transaction/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import asyncio
from collections.abc import Callable
from threading import RLock

from pymodbus.exceptions import ConnectionException, ModbusIOException
from pymodbus.framer import FramerBase
Expand Down Expand Up @@ -39,9 +40,10 @@ def __init__(
framer: FramerBase,
retries: int,
is_server: bool,
sync_client = None,
) -> None:
"""Initialize an instance of the ModbusTransactionManager."""
super().__init__(params, is_server)
super().__init__(params, is_server, is_sync=bool(sync_client))
self.framer = framer
self.retries = retries
self.next_tid: int = 0
Expand All @@ -51,14 +53,69 @@ def __init__(
self.trace_send_pdu: Callable[[ModbusPDU | None], ModbusPDU] | None = None
self.accept_no_response_limit = retries + 3
self.count_no_responses = 0
self._lock = asyncio.Lock()
if sync_client:
self.sync_client = sync_client
self._sync_lock = RLock()
else:
self._lock = asyncio.Lock()
self.response_future: asyncio.Future = asyncio.Future()

async def execute(self, no_response_expected: bool, request) -> ModbusPDU | None:
"""Execute requests asynchronously."""
def sync_get_response(self) -> ModbusPDU | None:
"""Receive until PDU is correct or timeout."""
databuffer = b''
while True:
if not (data := self.sync_client.recv(None)):
raise asyncio.exceptions.TimeoutError()
databuffer += data
used_len, pdu = self.framer.processIncomingFrame(databuffer)
databuffer = databuffer[used_len:]
if pdu:
return pdu

def sync_execute(self, no_response_expected: bool, request: ModbusPDU) -> ModbusPDU | None:
"""Execute requests asynchronously.
REMARK: this method is identical to execute, apart from the lock and sync_receive.
any changes in either method MUST be mirrored !!!
"""
if not self.transport:
Log.warning("Not connected, trying to connect!")
if not self.sync_client.connect():
raise ConnectionException("Client cannot connect (automatic retry continuing) !!")
with self._sync_lock:
request.transaction_id = self.getNextTID()
if self.trace_send_pdu:
request = self.trace_send_pdu(request) # pylint: disable=not-callable
packet = self.framer.buildFrame(request)
count_retries = 0
while count_retries <= self.retries:
if self.trace_send_packet:
packet = self.trace_send_packet(packet) # pylint: disable=not-callable
self.sync_client.send(packet)
if no_response_expected:
return None
try:
return self.sync_get_response()
except asyncio.exceptions.TimeoutError:
count_retries += 1
if self.count_no_responses >= self.accept_no_response_limit:
self.connection_lost(asyncio.TimeoutError("Server not responding"))
raise ModbusIOException(
f"ERROR: No response received of the last {self.accept_no_response_limit} request, CLOSING CONNECTION."
)
self.count_no_responses += 1
Log.error(f"No response received after {self.retries} retries, continue with next request")
return None

async def execute(self, no_response_expected: bool, request: ModbusPDU) -> ModbusPDU | None:
"""Execute requests asynchronously.
REMARK: this method is identical to sync_execute, apart from the lock and try/except.
any changes in either method MUST be mirrored !!!
"""
if not self.transport:
Log.warning("Not connected, trying to connect!")
if not self.transport and not await self.connect():
if not await self.connect():
raise ConnectionException("Client cannot connect (automatic retry continuing) !!")
async with self._lock:
request.transaction_id = self.getNextTID()
Expand Down Expand Up @@ -101,14 +158,12 @@ def callback_connected(self) -> None:

def callback_disconnected(self, exc: Exception | None) -> None:
"""Call when connection is lost."""
if self.trace_recv_packet:
self.trace_recv_packet(None) # pylint: disable=not-callable
if self.trace_recv_pdu:
self.trace_recv_pdu(None) # pylint: disable=not-callable
if self.trace_send_packet:
self.trace_send_packet(None) # pylint: disable=not-callable
if self.trace_send_pdu:
self.trace_send_pdu(None) # pylint: disable=not-callable
for call in (self.trace_recv_packet,
self.trace_recv_pdu,
self.trace_send_packet,
self.trace_send_pdu):
if call:
call(None) # pylint: disable=not-callable

def callback_data(self, data: bytes, addr: tuple | None = None) -> int:
"""Handle received data."""
Expand Down
8 changes: 7 additions & 1 deletion pymodbus/transport/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,18 +138,24 @@ def __init__(
self,
params: CommParams,
is_server: bool,
is_sync: bool = False
) -> None:
"""Initialize a transport instance.
:param params: parameter dataclass
:param is_server: true if object act as a server (listen/connect)
:param is_sync: true if used with sync client
"""
self.comm_params = params.copy()
self.is_server = is_server
self.is_closing = False

self.transport: asyncio.BaseTransport = None # type: ignore[assignment]
self.loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
self.loop: asyncio.AbstractEventLoop
if is_sync:
self.loop = asyncio.new_event_loop()
else:
self.loop = asyncio.get_running_loop()
self.recv_buffer: bytes = b""
self.call_create: Callable[[], Coroutine[Any, Any, Any]] = None # type: ignore[assignment]
self.reconnect_task: asyncio.Task | None = None
Expand Down
57 changes: 1 addition & 56 deletions test/transaction/test_old_transaction.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,14 @@
"""Test transaction."""
from unittest import mock

from pymodbus.client import ModbusTcpClient
from pymodbus.exceptions import (
ModbusIOException,
)
from pymodbus.framer import (
FramerAscii,
FramerRTU,
FramerSocket,
FramerTLS,
)
from pymodbus.pdu import DecodePDU, ModbusPDU
from pymodbus.transaction import (
ModbusTransactionManager,
SyncModbusTransactionManager,
)
from pymodbus.transaction.old_transaction import SyncModbusTransactionManager


TEST_MESSAGE = b"\x7b\x01\x03\x00\x00\x00\x05\x85\xC9\x7d"
Expand Down Expand Up @@ -51,54 +44,6 @@ def setup_method(self):
# Modbus transaction manager
# ----------------------------------------------------------------------- #

@mock.patch.object(ModbusTransactionManager, "getTransaction")
def test_execute(self, mock_get_transaction):
"""Test execute."""
client = ModbusTcpClient("localhost")
client.recv = mock.Mock()
client.framer = self._ascii
client.framer._buffer = b"deadbeef" # pylint: disable=protected-access
client.framer.processIncomingFrame = mock.MagicMock()
client.framer.processIncomingFrame.return_value = 0, None
client.framer.buildFrame = mock.MagicMock()
client.framer.buildFrame.return_value = b"deadbeef"
client.send = mock.MagicMock()
client.send.return_value = len(b"deadbeef")
request = mock.MagicMock()
request.get_response_pdu_size.return_value = 10
request.slave_id = 1
request.function_code = 222
trans = SyncModbusTransactionManager(client, 3)
assert trans.retries == 3

client.recv.side_effect=iter([b"abcdef", None])
mock_get_transaction.return_value = b"response"
trans.retries = 0
response = trans.execute(False, request)
assert isinstance(response, ModbusIOException)
# No response
client.recv.side_effect=iter([b"abcdef", None])
trans.transactions = {}
mock_get_transaction.return_value = None
response = trans.execute(False, request)
assert isinstance(response, ModbusIOException)

# No response with retries
client.recv.side_effect=iter([b"", b"abcdef"])
response = trans.execute(False, request)
assert isinstance(response, ModbusIOException)

# wrong handle_local_echo
client.recv.side_effect=iter([b"abcdef", b"deadbe", b"123456"])
client.comm_params.handle_local_echo = True
assert trans.execute(False, request).message == "[Input/Output] SEND failed"
client.comm_params.handle_local_echo = False

# retry on invalid response
client.recv.side_effect=iter([b"", b"abcdef", b"deadbe", b"123456"])
response = trans.execute(False, request)
assert isinstance(response, ModbusIOException)

def test_transaction_manager_tid(self):
"""Test the transaction manager TID."""
for tid in range(1, self._manager.getNextTID() + 10):
Expand Down
Loading

0 comments on commit 983230f

Please sign in to comment.