Skip to content

Commit

Permalink
update modbusrtuframer (#1435)
Browse files Browse the repository at this point in the history
Co-authored-by: dlmoffett
  • Loading branch information
janiversen authored Mar 22, 2023
1 parent 1cedc47 commit 99fd276
Show file tree
Hide file tree
Showing 8 changed files with 549 additions and 48 deletions.
22 changes: 11 additions & 11 deletions pymodbus/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def getFCdict(cls):
def __init__(self):
"""Initialize the client lookup tables."""
functions = {f.function_code for f in self.__function_table}
self.__lookup = self.getFCdict()
self.lookup = self.getFCdict()
self.__sub_lookup = {f: {} for f in functions}
for f in self.__sub_function_table:
self.__sub_lookup[f.function_code][f.sub_function_code] = f
Expand All @@ -187,7 +187,7 @@ def lookupPduClass(self, function_code):
:param function_code: The function code specified in a frame.
:returns: The class of the PDU that has a matching `function_code`.
"""
return self.__lookup.get(function_code, ExceptionResponse)
return self.lookup.get(function_code, ExceptionResponse)

def _helper(self, data):
"""Generate the correct request object from a valid request packet.
Expand All @@ -198,12 +198,12 @@ def _helper(self, data):
:returns: The decoded request or illegal function request object
"""
function_code = int(data[0])
if not (request := self.__lookup.get(function_code, lambda: None)()):
if not (request := self.lookup.get(function_code, lambda: None)()):
Log.debug("Factory Request[{}]", function_code)
request = IllegalFunctionRequest(function_code)
else:
fc_string = "%s: %s" % ( # pylint: disable=consider-using-f-string
str(self.__lookup[function_code]) # pylint: disable=use-maxsplit-arg
str(self.lookup[function_code]) # pylint: disable=use-maxsplit-arg
.split(".")[-1]
.rstrip('">"'),
function_code,
Expand All @@ -230,7 +230,7 @@ def register(self, function=None):
". Class needs to be derived from "
"`pymodbus.pdu.ModbusRequest` "
)
self.__lookup[function.function_code] = function
self.lookup[function.function_code] = function
if hasattr(function, "sub_function_code"):
if function.function_code not in self.__sub_lookup:
self.__sub_lookup[function.function_code] = {}
Expand Down Expand Up @@ -293,7 +293,7 @@ class ClientDecoder:
def __init__(self):
"""Initialize the client lookup tables."""
functions = {f.function_code for f in self.function_table}
self.__lookup = {f.function_code: f for f in self.function_table}
self.lookup = {f.function_code: f for f in self.function_table}
self.__sub_lookup = {f: {} for f in functions}
for f in self.__sub_function_table:
self.__sub_lookup[f.function_code][f.sub_function_code] = f
Expand All @@ -304,7 +304,7 @@ def lookupPduClass(self, function_code):
:param function_code: The function code specified in a frame.
:returns: The class of the PDU that has a matching `function_code`.
"""
return self.__lookup.get(function_code, ExceptionResponse)
return self.lookup.get(function_code, ExceptionResponse)

def decode(self, message):
"""Decode a response packet.
Expand All @@ -330,15 +330,15 @@ def _helper(self, data):
:raises ModbusException:
"""
fc_string = function_code = int(data[0])
if function_code in self.__lookup:
if function_code in self.lookup:
fc_string = "%s: %s" % ( # pylint: disable=consider-using-f-string
str(self.__lookup[function_code]) # pylint: disable=use-maxsplit-arg
str(self.lookup[function_code]) # pylint: disable=use-maxsplit-arg
.split(".")[-1]
.rstrip('">"'),
function_code,
)
Log.debug("Factory Response[{}]", fc_string)
response = self.__lookup.get(function_code, lambda: None)()
response = self.lookup.get(function_code, lambda: None)()
if function_code > 0x80:
code = function_code & 0x7F # strip error portion
response = ExceptionResponse(code, ecode.IllegalFunction)
Expand All @@ -361,7 +361,7 @@ def register(self, function):
". Class needs to be derived from "
"`pymodbus.pdu.ModbusResponse` "
)
self.__lookup[function.function_code] = function
self.lookup[function.function_code] = function
if hasattr(function, "sub_function_code"):
if function.function_code not in self.__sub_lookup:
self.__sub_lookup[function.function_code] = {}
Expand Down
51 changes: 35 additions & 16 deletions pymodbus/framer/rtu_framer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(self, decoder, client=None):
self._hsize = 0x01
self._end = b"\x0d\x0a"
self._min_frame_size = 4
self.function_codes = set(self.decoder.lookup) if self.decoder else {}

# ----------------------------------------------------------------------- #
# Private Helper Functions
Expand Down Expand Up @@ -117,7 +118,7 @@ def resetFrame(self):
Log.debug(
"Resetting frame - Current Frame in buffer - {}", self._buffer, ":hex"
)
self._buffer = b""
# self._buffer = b""
self._header = {"uid": 0x00, "len": 0, "crc": b"\x00\x00"}

def isFrameReady(self):
Expand Down Expand Up @@ -191,6 +192,23 @@ def populateResult(self, result):
result.slave_id = self._header["uid"]
result.transaction_id = self._header["uid"]

def getFrameStart(self, slaves, broadcast, skip_cur_frame):
"""Scan buffer for a relevant frame start."""
start = 1 if skip_cur_frame else 0
if (buf_len := len(self._buffer)) < 4:
return False
for i in range(start, buf_len - 3): # <slave id><function code><crc 2 bytes>
if not broadcast and self._buffer[i] not in slaves:
continue
if self._buffer[i + 1] not in self.function_codes:
continue
if i:
self._buffer = self._buffer[i:] # remove preceding trash.
return True
if buf_len > 3:
self._buffer = self._buffer[-3:]
return False

# ----------------------------------------------------------------------- #
# Public Member Functions
# ----------------------------------------------------------------------- #
Expand All @@ -214,25 +232,26 @@ def processIncomingPacket(self, data, callback, slave, **kwargs):
"""
if not isinstance(slave, (list, tuple)):
slave = [slave]
broadcast = not slave[0]
self.addToFrame(data)
single = kwargs.get("single", False)
while True:
if self.isFrameReady():
if self.checkFrame():
if self._validate_slave_id(slave, single):
self._process(callback)
else:
header_txt = self._header["uid"]
Log.debug("Not a valid slave id - {}, ignoring!!", header_txt)
self.resetFrame()
break
else:
Log.debug("Frame check failed, ignoring!!")
self.resetFrame()
break
else:
skip_cur_frame = False
while self.getFrameStart(slave, broadcast, skip_cur_frame):
if not self.isFrameReady():
Log.debug("Frame - [{}] not ready", data)
break
if not self.checkFrame():
Log.debug("Frame check failed, ignoring!!")
self.resetFrame()
skip_cur_frame = True
continue
if not self._validate_slave_id(slave, single):
header_txt = self._header["uid"]
Log.debug("Not a valid slave id - {}, ignoring!!", header_txt)
self.resetFrame()
skip_cur_frame = True
continue
self._process(callback)

def buildPacket(self, message):
"""Create a ready to send modbus packet.
Expand Down
87 changes: 68 additions & 19 deletions test/test_client_multidrop.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,25 @@ def test_ok_frame(self, framer, callback):
framer.processIncomingPacket(serial_event, callback, self.slaves)
callback.assert_called_once()

def test_ok_2frame(self, framer, callback):
"""Test ok frame."""
serial_event = self.good_frame + self.good_frame
framer.processIncomingPacket(serial_event, callback, self.slaves)
assert callback.call_count == 2

def test_bad_crc(self, framer, callback):
"""Test bad crc."""
serial_event = b"\x02\x03\x00\x01\x00}\xd4\x19" # Manually mangled crc
framer.processIncomingPacket(serial_event, callback, self.slaves)
callback.assert_not_called()

def test_wrong_unit(self, framer, callback):
"""Test frame wrong unit"""
serial_event = (
b"\x01\x03\x00\x01\x00}\xd4+" # Frame with good CRC but other unit id
)
def test_wrong_id(self, framer, callback):
"""Test frame wrong id"""
serial_event = b"\x01\x03\x00\x01\x00}\xd4+" # Frame with good CRC but other id
framer.processIncomingPacket(serial_event, callback, self.slaves)
callback.assert_not_called()

def test_big_split_response_frame_from_other_unit(self, framer, callback):
def test_big_split_response_frame_from_other_id(self, framer, callback):
"""Test split response."""
# This is a single *response* from device id 1 after being queried for 125 holding register values
# Because the response is so long it spans several serial events
Expand All @@ -70,33 +74,52 @@ def test_split_frame(self, framer, callback):
framer.processIncomingPacket(serial_event, callback, self.slaves)
callback.assert_called_once()

@pytest.mark.skip
def test_complete_frame_trailing_data_without_unit_id(self, framer, callback):
def test_complete_frame_trailing_data_without_id(self, framer, callback):
"""Test trailing data."""
garbage = b"\x05\x04\x03" # Note the garbage doesn't contain our unit id
garbage = b"\x05\x04\x03" # without id
serial_event = garbage + self.good_frame
framer.processIncomingPacket(serial_event, callback, self.slaves)
callback.assert_called_once()

@pytest.mark.skip
def test_complete_frame_trailing_data_with_unit_id(self, framer, callback):
def test_complete_frame_trailing_data_with_id(self, framer, callback):
"""Test trailing data."""
garbage = (
b"\x05\x04\x03\x02\x01\x00" # Note the garbage does contain our unit id
)
garbage = b"\x05\x04\x03\x02\x01\x00" # with id
serial_event = garbage + self.good_frame
framer.processIncomingPacket(serial_event, callback, self.slaves)
callback.assert_called_once()

@pytest.mark.skip
def test_split_frame_trailing_data_with_unit_id(self, framer, callback):
def test_split_frame_trailing_data_with_id(self, framer, callback):
"""Test split frame."""
garbage = b"\x05\x04\x03\x02\x01\x00"
serial_events = [garbage + self.good_frame[:5], self.good_frame[5:]]
for serial_event in serial_events:
framer.processIncomingPacket(serial_event, callback, self.slaves)
callback.assert_called_once()

def test_coincidental_1(self, framer, callback):
"""Test conincidental."""
garbage = b"\x02\x90\x07"
serial_events = [garbage, self.good_frame[:5], self.good_frame[5:]]
for serial_event in serial_events:
framer.processIncomingPacket(serial_event, callback, self.slaves)
callback.assert_called_once()

def test_coincidental_2(self, framer, callback):
"""Test conincidental."""
garbage = b"\x02\x10\x07"
serial_events = [garbage, self.good_frame[:5], self.good_frame[5:]]
for serial_event in serial_events:
framer.processIncomingPacket(serial_event, callback, self.slaves)
callback.assert_called_once()

def test_coincidental_3(self, framer, callback):
"""Test conincidental."""
garbage = b"\x02\x10\x07\x10"
serial_events = [garbage, self.good_frame[:5], self.good_frame[5:]]
for serial_event in serial_events:
framer.processIncomingPacket(serial_event, callback, self.slaves)
callback.assert_called_once()

def test_wrapped_frame(self, framer, callback):
"""Test wrapped frame."""
garbage = b"\x05\x04\x03\x02\x01\x00"
Expand All @@ -107,14 +130,40 @@ def test_wrapped_frame(self, framer, callback):
# i.e. this probably represents a case where a command came for us, but we didn't get
# to the serial buffer in time (some other co-routine or perhaps a block on the USB bus)
# and the master moved on and queried another device
callback.assert_not_called()
callback.assert_called_once()

@pytest.mark.skip
def test_frame_with_trailing_data(self, framer, callback):
"""Test trailing data."""
garbage = b"\x05\x04\x03\x02\x01\x00"
serial_event = self.good_frame + garbage
framer.processIncomingPacket(serial_event, callback, self.slaves)

# We should not respond in this case for identical reasons as test_wrapped_frame
callback.assert_not_called()
callback.assert_called_once()

def test_getFrameStart(self, framer):
"""Test getFrameStart."""
framer_ok = b"\x02\x03\x00\x01\x00}\xd4\x18"
framer._buffer = framer_ok # pylint: disable=protected-access
assert framer.getFrameStart(self.slaves, False, False)
assert framer_ok == framer._buffer # pylint: disable=protected-access

framer_2ok = framer_ok + framer_ok
framer._buffer = framer_2ok # pylint: disable=protected-access
assert framer.getFrameStart(self.slaves, False, False)
assert framer_2ok == framer._buffer # pylint: disable=protected-access
assert framer.getFrameStart(self.slaves, False, True)
assert framer_ok == framer._buffer # pylint: disable=protected-access

framer._buffer = framer_ok[:2] # pylint: disable=protected-access
assert not framer.getFrameStart(self.slaves, False, False)
assert framer_ok[:2] == framer._buffer # pylint: disable=protected-access

framer._buffer = framer_ok[:3] # pylint: disable=protected-access
assert not framer.getFrameStart(self.slaves, False, False)
assert framer_ok[:3] == framer._buffer # pylint: disable=protected-access

framer_ok = b"\xF0\x03\x00\x01\x00}\xd4\x18"
framer._buffer = framer_ok # pylint: disable=protected-access
assert not framer.getFrameStart(self.slaves, False, False)
assert framer._buffer == framer_ok[-3:] # pylint: disable=protected-access
3 changes: 1 addition & 2 deletions test/test_framers.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ def test_rtu_reset_framer(rtu_framer, data): # pylint: disable=redefined-outer-
"len": 0,
"crc": b"\x00\x00",
}
assert rtu_framer._buffer == b"" # pylint: disable=protected-access


@pytest.mark.parametrize(
Expand Down Expand Up @@ -255,7 +254,7 @@ def test_populate_result(rtu_framer): # pylint: disable=redefined-outer-name
(
b"\x11\x03\x06\xAE\x41\x56\x52\x43\x40\x49\xAD",
16,
True,
False,
False,
), # incorrect slave id
(b"\x11\x03\x06\xAE\x41\x56\x52\x43\x40\x49\xAD\x11\x03", 17, False, True),
Expand Down
Loading

0 comments on commit 99fd276

Please sign in to comment.