Skip to content

Commit

Permalink
Simplify syncTransactionManager.
Browse files Browse the repository at this point in the history
  • Loading branch information
janiversen committed Nov 7, 2024
1 parent d137c6e commit 9aa8a07
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 204 deletions.
10 changes: 8 additions & 2 deletions pymodbus/framer/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,16 @@ class FramerTLS(FramerBase):
"""Modbus TLS frame type.
Layout::
[ Function Code] [ Data ]
1b Nb
[ MBAP Header ] [ Function Code] [ Data ]
[ tid ][ pid ][ length ][ uid ]
2b 2b 2b 1b 1b Nb
length = uid + function code + data
"""

MIN_SIZE = 8

def decode(self, data: bytes) -> tuple[int, int, int, bytes]:
"""Decode MDAP+PDU."""
tid = int.from_bytes(data[0:2], 'big')
Expand Down
228 changes: 68 additions & 160 deletions pymodbus/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,22 +113,7 @@ def reset(self):


class SyncModbusTransactionManager(ModbusTransactionManager):
"""Implement a transaction for a manager.
The transaction protocol can be represented by the following pseudo code::
count = 0
do
result = send(message)
if (timeout or result == bad)
count++
else break
while (count < 3)
This module helps to abstract this away from the framer and protocol.
Results are keyed based on the supplied transaction id.
"""
"""Implement a transaction for a manager."""

def __init__(self, client: ModbusBaseSyncClient, retries):
"""Initialize an instance of the ModbusTransactionManager."""
Expand All @@ -137,44 +122,6 @@ def __init__(self, client: ModbusBaseSyncClient, retries):
self.retries = retries
self._transaction_lock = RLock()
self.databuffer = b''
if client:
self._set_adu_size()

def _set_adu_size(self):
"""Set adu size."""
# base ADU size of modbus frame in bytes
if isinstance(self.client.framer, FramerSocket):
self.base_adu_size = 7 # tid(2), pid(2), length(2), uid(1)
elif isinstance(self.client.framer, FramerRTU):
self.base_adu_size = 3 # address(1), CRC(2)
elif isinstance(self.client.framer, FramerAscii):
self.base_adu_size = 7 # start(1)+ Address(2), LRC(2) + end(2)
elif isinstance(self.client.framer, FramerTLS):
self.base_adu_size = 0 # no header and footer
else:
self.base_adu_size = -1

def _calculate_response_length(self, expected_pdu_size):
"""Calculate response length."""
if self.base_adu_size == -1:
return None
return self.base_adu_size + expected_pdu_size

def _calculate_exception_length(self):
"""Return the length of the Modbus Exception Response according to the type of Framer."""
if isinstance(self.client.framer, (FramerSocket, FramerTLS)):
return self.base_adu_size + 2 # Fcode(1), ExceptionCode(1)
if isinstance(self.client.framer, FramerAscii):
return self.base_adu_size + 4 # Fcode(2), ExceptionCode(2)
if isinstance(self.client.framer, FramerRTU):
return self.base_adu_size + 2 # Fcode(1), ExceptionCode(1)
return None

def _validate_response(self, response):
"""Validate Incoming response against request."""
if not response:
return False
return True

def execute(self, no_response_expected: bool, request: ModbusPDU): # noqa: C901
"""Start the producer to send the next request to consumer.write(Frame(request))."""
Expand All @@ -200,19 +147,11 @@ def execute(self, no_response_expected: bool, request: ModbusPDU): # noqa: C901
if isinstance(self.client.framer, FramerAscii):
response_pdu_size *= 2
if response_pdu_size:
expected_response_length = (
self._calculate_response_length(response_pdu_size)
)
full = False
if self.client.comm_params.comm_type == CommType.UDP:
full = True
if not expected_response_length:
expected_response_length = 1024
expected_response_length = self.client.framer.MIN_SIZE + response_pdu_size -1
response, last_exception = self._transact(
no_response_expected,
request,
expected_response_length,
full=full,
)
if no_response_expected:
return None
Expand Down Expand Up @@ -248,7 +187,7 @@ def execute(self, no_response_expected: bool, request: ModbusPDU): # noqa: C901
self.client.close()
return exc

def _retry_transaction(self, no_response_expected, retries, reason, packet, response_length, full=False):
def _retry_transaction(self, no_response_expected, retries, reason, packet, response_length):
"""Retry transaction."""
Log.debug("Retry on {} response - {}", reason, retries)
Log.debug('Changing transaction state from "WAITING_FOR_REPLY" to "RETRYING"')
Expand All @@ -259,128 +198,97 @@ def _retry_transaction(self, no_response_expected, retries, reason, packet, resp
in_waiting := self.client._in_waiting() # pylint: disable=protected-access
):
if response_length == in_waiting:
result = self._recv(response_length, full)
result = self._recv(response_length)
return result, None
return self._transact(no_response_expected, packet, response_length, full=full)
return self._transact(no_response_expected, packet, response_length)

def _transact(self, no_response_expected: bool, request: ModbusPDU, response_length, full=False):
def _transact(self, no_response_expected: bool, request: ModbusPDU, response_length):
"""Do a Write and Read transaction."""
last_exception = None
try:
self.client.connect()
packet = self.client.framer.buildFrame(request)
Log.debug("SEND: {}", packet, ":hex")
size = self._send(packet)
if (
isinstance(size, bytes)
and self.client.state == ModbusTransactionState.RETRYING
):
Log.debug(
"Changing transaction state from "
'"RETRYING" to "PROCESSING REPLY"'
)
self.client.state = ModbusTransactionState.PROCESSING_REPLY
return size, None
if self.client.comm_params.handle_local_echo is True:
if self._recv(size, full) != packet:
return b"", "Wrong local echo"
if (size := self.client.send(packet)) != len(packet):
return b"", f"Only sent {size} of {len(packet)} bytes"
if self.client.comm_params.handle_local_echo and self.client.recv(size) != packet:
return b"", "Wrong local echo"
if no_response_expected:
if size:
Log.debug(
'Changing transaction state from "SENDING" '
'to "TRANSACTION_COMPLETE"'
)
self.client.state = ModbusTransactionState.TRANSACTION_COMPLETE
return b"", None
if size:
Log.debug(
'Changing transaction state from "SENDING" '
'to "WAITING FOR REPLY"'
'to "TRANSACTION_COMPLETE" (no response expected)'
)
self.client.state = ModbusTransactionState.WAITING_FOR_REPLY
result = self._recv(response_length, full)
# result2 = self._recv(response_length, full)
self.client.state = ModbusTransactionState.TRANSACTION_COMPLETE
return b"", None
state = '"RETRYING"' if self.client.state == ModbusTransactionState.RETRYING else '"SENDING"'
Log.debug(f'Changing transaction state from {state} to "WAITING FOR REPLY"')
self.client.state = ModbusTransactionState.WAITING_FOR_REPLY
result = self._recv(response_length)
Log.debug("RECV: {}", result, ":hex")
return result, None
except (OSError, ModbusIOException, InvalidMessageReceivedException, ConnectionException) as msg:
self.client.close()
Log.debug("Transaction failed. ({}) ", msg)
last_exception = msg
result = b""
return result, last_exception

def _send(self, packet: bytes, _retrying=False):
"""Send."""
return self.client.send(packet)
return b"", msg

def _recv(self, expected_response_length, full) -> bytes: # noqa: C901
def _recv(self, expected_response_length) -> bytes: # noqa: C901
"""Receive."""
if self.client.comm_params.comm_type == CommType.UDP:
read_min = self.client.recv(500)
else:
read_min = self.client.recv(self.client.framer.MIN_SIZE)
if (min_size := len(read_min)) < self.client.framer.MIN_SIZE:
msg_start = "Incomplete message" if read_min else "No response"
raise InvalidMessageReceivedException(
f"{msg_start} received, expected at least {self.client.framer.MIN_SIZE} bytes "
f"({min_size} received)"
)

if isinstance(self.client.framer, (FramerSocket, FramerTLS)):
func_code = int(read_min[self.client.framer.MIN_SIZE-1])
elif isinstance(self.client.framer, FramerRTU):
func_code = int(read_min[1])
elif isinstance(self.client.framer, FramerAscii):
func_code = int(read_min[3:5], 16)
else:
func_code = -1

total = None
if not full:
exception_length = self._calculate_exception_length()
if isinstance(self.client.framer, FramerSocket):
min_size = 8
elif isinstance(self.client.framer, FramerRTU):
min_size = 4
elif isinstance(self.client.framer, FramerAscii):
min_size = 5
else:
min_size = expected_response_length

read_min = self.client.recv(min_size)
if min_size and len(read_min) != min_size:
msg_start = "Incomplete message" if read_min else "No response"
raise InvalidMessageReceivedException(
f"{msg_start} received, expected at least {min_size} bytes "
f"({len(read_min)} received)"
)
if read_min:
if isinstance(self.client.framer, FramerSocket):
func_code = int(read_min[-1])
elif isinstance(self.client.framer, FramerRTU):
func_code = int(read_min[1])
elif isinstance(self.client.framer, FramerAscii):
func_code = int(read_min[3:5], 16)
else:
func_code = -1

if func_code < 0x80: # Not an error
if isinstance(self.client.framer, FramerSocket):
length = struct.unpack(">H", read_min[4:6])[0] - 1
expected_response_length = 7 + length
elif expected_response_length is None and isinstance(
self.client.framer, FramerRTU
):
with suppress(
IndexError # response length indeterminate with available bytes
):
expected_response_length = (
self._get_expected_response_length(
read_min
)
)
if expected_response_length is not None:
expected_response_length -= min_size
total = expected_response_length + min_size
else:
expected_response_length = exception_length - min_size
total = expected_response_length + min_size
if func_code < 0x80: # Not an error
if isinstance(self.client.framer, (FramerSocket, FramerTLS)):
length = struct.unpack(">H", read_min[4:6])[0] - 1
expected_response_length = 7 + length
elif expected_response_length is None and isinstance(
self.client.framer, FramerRTU
):
with suppress(
IndexError # response length indeterminate with available bytes
):
expected_response_length = (
self._get_expected_response_length(
read_min
)
)
if expected_response_length is not None:
expected_response_length -= min_size
total = expected_response_length + min_size
else:
if isinstance(self.client.framer, FramerAscii):
total = self.client.framer.MIN_SIZE + 2 # ExceptionCode(2)
else:
total = expected_response_length
total = self.client.framer.MIN_SIZE + 1 # ExceptionCode(1)
expected_response_length = total - min_size
result = read_min

if total and (missing_len := total - min_size):
retries = 0
missing_len = expected_response_length
result = read_min
while missing_len and retries < self.retries:
if retries:
time.sleep(0.1)
data = self.client.recv(expected_response_length)
data = self.client.recv(missing_len)
result += data
missing_len -= len(data)
retries += 1
else:
read_min = b""
total = expected_response_length
result = self.client.recv(expected_response_length)
result = read_min + result

actual = len(result)
if total is not None and actual != total:
msg_start = "Incomplete message" if actual else "No response"
Expand Down
45 changes: 3 additions & 42 deletions test/sub_current/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,53 +42,14 @@ def setup_method(self):
self._tls = FramerTLS(self.decoder)
self._rtu = FramerRTU(self.decoder)
self._ascii = FramerAscii(self.decoder)
self._manager = SyncModbusTransactionManager(None, 3)
client = mock.MagicMock()
client.framer = self._rtu
self._manager = SyncModbusTransactionManager(client, 3)

# ----------------------------------------------------------------------- #
# Modbus transaction manager
# ----------------------------------------------------------------------- #

def test_calculate_expected_response_length(self):
"""Test calculate expected response length."""
self._manager.client = mock.MagicMock()
self._manager.client.framer = mock.MagicMock()
self._manager._set_adu_size() # pylint: disable=protected-access
assert not self._manager._calculate_response_length( # pylint: disable=protected-access
0
)
self._manager.base_adu_size = 10
assert (
self._manager._calculate_response_length(5) # pylint: disable=protected-access
== 15
)

def test_calculate_exception_length(self):
"""Test calculate exception length."""
for framer, exception_length in (
("ascii", 11),
("rtu", 5),
("tcp", 9),
("tls", 2),
("dummy", None),
):
self._manager.client = mock.MagicMock()
if framer == "ascii":
self._manager.client.framer = self._ascii
elif framer == "rtu":
self._manager.client.framer = self._rtu
elif framer == "tcp":
self._manager.client.framer = self._tcp
elif framer == "tls":
self._manager.client.framer = self._tls
else:
self._manager.client.framer = mock.MagicMock()

self._manager._set_adu_size() # pylint: disable=protected-access
assert (
self._manager._calculate_exception_length() # pylint: disable=protected-access
== exception_length
)

@mock.patch.object(SyncModbusTransactionManager, "_recv")
@mock.patch.object(ModbusTransactionManager, "getTransaction")
def test_execute(self, mock_get_transaction, mock_recv):
Expand Down

0 comments on commit 9aa8a07

Please sign in to comment.