Skip to content

Commit

Permalink
Streamline message class. (#2133)
Browse files Browse the repository at this point in the history
  • Loading branch information
janiversen committed Jun 18, 2024
1 parent 059a1b6 commit 66d3d64
Show file tree
Hide file tree
Showing 14 changed files with 32 additions and 43 deletions.
2 changes: 1 addition & 1 deletion pymodbus/framer/ascii_framer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, decoder, client=None):
self._hsize = 0x02
self._start = b":"
self._end = b"\r\n"
self.message_handler = MessageAscii([0], True)
self.message_handler = MessageAscii()

def decode_data(self, data):
"""Decode data."""
Expand Down
2 changes: 1 addition & 1 deletion pymodbus/framer/rtu_framer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(self, decoder, client=None):
self._end = b"\x0d\x0a"
self._min_frame_size = 4
self.function_codes = decoder.lookup.keys() if decoder else {}
self.message_handler = MessageRTU([0], True)
self.message_handler = MessageRTU()

def decode_data(self, data):
"""Decode data."""
Expand Down
2 changes: 1 addition & 1 deletion pymodbus/framer/socket_framer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, decoder, client=None):
"""
super().__init__(decoder, client)
self._hsize = 0x07
self.message_handler = MessageSocket([0], True)
self.message_handler = MessageSocket()

def decode_data(self, data):
"""Decode data."""
Expand Down
2 changes: 1 addition & 1 deletion pymodbus/framer/tls_framer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, decoder, client=None):
"""
super().__init__(decoder, client)
self._hsize = 0x0
self.message_handler = MessageTLS([0], True)
self.message_handler = MessageTLS()

def decode_data(self, data):
"""Decode data."""
Expand Down
18 changes: 2 additions & 16 deletions pymodbus/message/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,9 @@ class MessageBase:

EMPTY = b''

def __init__(
self,
device_ids: list[int],
is_server: bool,
) -> None:
"""Initialize a message instance.
:param device_ids: list of device id to accept (server only), None for all.
"""
self.device_ids = device_ids
self.is_server = is_server
self.broadcast: bool = (0 in device_ids)

def __init__(self) -> None:
"""Initialize a message instance."""

def validate_device_id(self, dev_id: int) -> bool:
"""Check if device id is expected."""
return self.broadcast or (dev_id in self.device_ids)

@abstractmethod
def decode(self, _data: bytes) -> tuple[int, int, int, bytes]:
Expand Down
18 changes: 12 additions & 6 deletions pymodbus/message/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,21 @@ def __init__(self,
"""
super().__init__(params, is_server)
self.device_ids = device_ids
self.message_type = message_type
self.broadcast: bool = (0 in device_ids)
self.msg_handle: MessageBase = {
MessageType.RAW: MessageRaw(device_ids, is_server),
MessageType.ASCII: MessageAscii(device_ids, is_server),
MessageType.RTU: MessageRTU(device_ids, is_server),
MessageType.SOCKET: MessageSocket(device_ids, is_server),
MessageType.TLS: MessageTLS(device_ids, is_server),
MessageType.RAW: MessageRaw(),
MessageType.ASCII: MessageAscii(),
MessageType.RTU: MessageRTU(),
MessageType.SOCKET: MessageSocket(),
MessageType.TLS: MessageTLS(),
}[message_type]


def validate_device_id(self, dev_id: int) -> bool:
"""Check if device id is expected."""
return self.broadcast or (dev_id in self.device_ids)


def callback_data(self, data: bytes, addr: tuple | None = None) -> int:
"""Handle received data."""
tot_len = len(data)
Expand Down
3 changes: 0 additions & 3 deletions pymodbus/message/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ def decode(self, data: bytes) -> tuple[int, int, int, bytes]:
return 0, 0, 0, self.EMPTY
dev_id = int(data[0])
tid = int(data[1])
if not self.validate_device_id(dev_id):
Log.debug("Device id: {} in frame {} unknown, skipping.", dev_id, data, ":hex")

return len(data), dev_id, tid, data[2:]

def encode(self, data: bytes, device_id: int, tid: int) -> bytes:
Expand Down
2 changes: 1 addition & 1 deletion pymodbus/message/rtu.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def callback(result):
nonlocal resp
resp = result

self._legacy_decode(callback, self.device_ids)
self._legacy_decode(callback, [0])
return 0, 0, 0, b''


Expand Down
1 change: 1 addition & 0 deletions test/message/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(self,
"""Initialize a message instance."""
super().__init__(message_type, params, is_server, device_ids)
self.send = mock.Mock()
self.message_type = message_type

def callback_new_connection(self) -> ModbusProtocol:
"""Call when listener receive new connection request."""
Expand Down
2 changes: 1 addition & 1 deletion test/message/test_ascii.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class TestMessageAscii:
@pytest.fixture(name="frame")
def prepare_frame():
"""Return message object."""
return MessageAscii([1], False)
return MessageAscii()


@pytest.mark.parametrize(
Expand Down
17 changes: 8 additions & 9 deletions test/message/test_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ async def test_message_build_send(self, msg):
])
async def test_validate_id(self, msg, dev_id, res):
"""Test message type."""
assert res == msg.msg_handle.validate_device_id(dev_id)
assert res == msg.validate_device_id(dev_id)

@pytest.mark.parametrize(
("data", "res_len", "res_id", "res_tid", "res_data"), [
Expand Down Expand Up @@ -159,7 +159,7 @@ class TestMessages:
b'\x11\x03\x00\x7c\x00\x02\x07\x43',
b'\x11\x03\x04\x00\x8d\x00\x8e\xfb\xbd',
b'\x11\x83\x02\xc1\x34',
b'\xff\x03\x00|\x00\x02\x10\x0d',
b'\xff\x03\x00\x7c\x00\x02\x10\x0d',
b'\xff\x03\x04\x00\x8d\x00\x8e\xf5\xb3',
b'\xff\x83\x02\xa1\x01',
]),
Expand Down Expand Up @@ -215,11 +215,10 @@ class TestMessages:
)
def test_encode(self, frame, frame_expected, data, dev_id, tid, inx1, inx2, inx3):
"""Test encode method."""
if frame != MessageSocket and tid:
pytest.skip("Not supported")
if frame == MessageTLS and (tid or dev_id):
pytest.skip("Not supported")
frame_obj = frame([0], True)
if ((frame != MessageSocket and tid) or
(frame == MessageTLS and dev_id)):
return
frame_obj = frame()
expected = frame_expected[inx1 + inx2 + inx3]
encoded_data = frame_obj.encode(data, dev_id, tid)
assert encoded_data == expected
Expand Down Expand Up @@ -281,7 +280,7 @@ async def test_decode(self, dummy_message, msg_type, data, dev_id, tid, expected
if msg_type == MessageType.RTU:
pytest.skip("Waiting on implementation!")
if msg_type == MessageType.TLS and split != "no":
pytest.skip("Not supported.")
return
frame = dummy_message(
msg_type,
CommParams(),
Expand Down Expand Up @@ -323,7 +322,7 @@ async def test_decode_bad_crc(self, frame, data, exp_len):
"""Test encode method."""
if frame == MessageRTU:
pytest.skip("Waiting for implementation.")
frame_obj = frame([0], True)
frame_obj = frame()
used_len, _, _, data = frame_obj.decode(data)
assert used_len == exp_len
assert not data
2 changes: 1 addition & 1 deletion test/message/test_rtu.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class TestMessageRTU:
@pytest.fixture(name="frame")
def prepare_frame():
"""Return message object."""
return MessageRTU([1], False)
return MessageRTU()


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion test/message/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class TestMessageSocket:
@pytest.fixture(name="frame")
def prepare_frame():
"""Return message object."""
return MessageSocket([1], False)
return MessageSocket()


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion test/message/test_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class TestMessageSocket:
@pytest.fixture(name="frame")
def prepare_frame():
"""Return message object."""
return MessageTLS([1], False)
return MessageTLS()


@pytest.mark.parametrize(
Expand Down

0 comments on commit 66d3d64

Please sign in to comment.