From 5b1c1951f9e573fbe3548e2504c43a095a58f7bf Mon Sep 17 00:00:00 2001 From: jan iversen Date: Thu, 28 Mar 2024 21:18:05 +0100 Subject: [PATCH] Streamline message class. (#2133) --- pymodbus/framer/ascii_framer.py | 2 +- pymodbus/framer/rtu_framer.py | 2 +- pymodbus/framer/socket_framer.py | 2 +- pymodbus/framer/tls_framer.py | 2 +- pymodbus/message/base.py | 18 ++---------------- pymodbus/message/message.py | 18 ++++++++++++------ pymodbus/message/raw.py | 3 --- pymodbus/message/rtu.py | 2 +- test/message/conftest.py | 1 + test/message/test_ascii.py | 2 +- test/message/test_message.py | 17 ++++++++--------- test/message/test_rtu.py | 2 +- test/message/test_socket.py | 2 +- test/message/test_tls.py | 2 +- 14 files changed, 32 insertions(+), 43 deletions(-) diff --git a/pymodbus/framer/ascii_framer.py b/pymodbus/framer/ascii_framer.py index 8ce7a67b4..4d2fbc68a 100644 --- a/pymodbus/framer/ascii_framer.py +++ b/pymodbus/framer/ascii_framer.py @@ -39,7 +39,7 @@ def __init__(self, decoder, client=None): self._hsize = 0x02 self._start = b":" self._end = b"\r\n" - self.message_handler = MessageAscii([0], True) + self.message_handler = MessageAscii() def decode_data(self, data): """Decode data.""" diff --git a/pymodbus/framer/rtu_framer.py b/pymodbus/framer/rtu_framer.py index dd6ec817c..6455a28f8 100644 --- a/pymodbus/framer/rtu_framer.py +++ b/pymodbus/framer/rtu_framer.py @@ -61,7 +61,7 @@ def __init__(self, decoder, client=None): self._end = b"\x0d\x0a" self._min_frame_size = 4 self.function_codes = decoder.lookup.keys() if decoder else {} - self.message_handler = MessageRTU([0], True) + self.message_handler = MessageRTU() def decode_data(self, data): """Decode data.""" diff --git a/pymodbus/framer/socket_framer.py b/pymodbus/framer/socket_framer.py index 8460c8a15..582aa283a 100644 --- a/pymodbus/framer/socket_framer.py +++ b/pymodbus/framer/socket_framer.py @@ -43,7 +43,7 @@ def __init__(self, decoder, client=None): """ super().__init__(decoder, client) self._hsize = 0x07 - self.message_handler = MessageSocket([0], True) + self.message_handler = MessageSocket() def decode_data(self, data): """Decode data.""" diff --git a/pymodbus/framer/tls_framer.py b/pymodbus/framer/tls_framer.py index a39de0321..1b341b224 100644 --- a/pymodbus/framer/tls_framer.py +++ b/pymodbus/framer/tls_framer.py @@ -34,7 +34,7 @@ def __init__(self, decoder, client=None): """ super().__init__(decoder, client) self._hsize = 0x0 - self.message_handler = MessageTLS([0], True) + self.message_handler = MessageTLS() def decode_data(self, data): """Decode data.""" diff --git a/pymodbus/message/base.py b/pymodbus/message/base.py index 831cb57b2..3a41bb9eb 100644 --- a/pymodbus/message/base.py +++ b/pymodbus/message/base.py @@ -14,23 +14,9 @@ class MessageBase: EMPTY = b'' - def __init__( - self, - device_ids: list[int], - is_server: bool, - ) -> None: - """Initialize a message instance. - - :param device_ids: list of device id to accept (server only), None for all. - """ - self.device_ids = device_ids - self.is_server = is_server - self.broadcast: bool = (0 in device_ids) - + def __init__(self) -> None: + """Initialize a message instance.""" - def validate_device_id(self, dev_id: int) -> bool: - """Check if device id is expected.""" - return self.broadcast or (dev_id in self.device_ids) @abstractmethod def decode(self, _data: bytes) -> tuple[int, int, int, bytes]: diff --git a/pymodbus/message/message.py b/pymodbus/message/message.py index 0f1967843..20427bd23 100644 --- a/pymodbus/message/message.py +++ b/pymodbus/message/message.py @@ -65,15 +65,21 @@ def __init__(self, """ super().__init__(params, is_server) self.device_ids = device_ids - self.message_type = message_type + self.broadcast: bool = (0 in device_ids) self.msg_handle: MessageBase = { - MessageType.RAW: MessageRaw(device_ids, is_server), - MessageType.ASCII: MessageAscii(device_ids, is_server), - MessageType.RTU: MessageRTU(device_ids, is_server), - MessageType.SOCKET: MessageSocket(device_ids, is_server), - MessageType.TLS: MessageTLS(device_ids, is_server), + MessageType.RAW: MessageRaw(), + MessageType.ASCII: MessageAscii(), + MessageType.RTU: MessageRTU(), + MessageType.SOCKET: MessageSocket(), + MessageType.TLS: MessageTLS(), }[message_type] + + def validate_device_id(self, dev_id: int) -> bool: + """Check if device id is expected.""" + return self.broadcast or (dev_id in self.device_ids) + + def callback_data(self, data: bytes, addr: tuple | None = None) -> int: """Handle received data.""" tot_len = len(data) diff --git a/pymodbus/message/raw.py b/pymodbus/message/raw.py index e75324556..88627482b 100644 --- a/pymodbus/message/raw.py +++ b/pymodbus/message/raw.py @@ -23,9 +23,6 @@ def decode(self, data: bytes) -> tuple[int, int, int, bytes]: return 0, 0, 0, self.EMPTY dev_id = int(data[0]) tid = int(data[1]) - if not self.validate_device_id(dev_id): - Log.debug("Device id: {} in frame {} unknown, skipping.", dev_id, data, ":hex") - return len(data), dev_id, tid, data[2:] def encode(self, data: bytes, device_id: int, tid: int) -> bytes: diff --git a/pymodbus/message/rtu.py b/pymodbus/message/rtu.py index fa5d7dad3..2565fa5a9 100644 --- a/pymodbus/message/rtu.py +++ b/pymodbus/message/rtu.py @@ -177,7 +177,7 @@ def callback(result): nonlocal resp resp = result - self._legacy_decode(callback, self.device_ids) + self._legacy_decode(callback, [0]) return 0, 0, 0, b'' diff --git a/test/message/conftest.py b/test/message/conftest.py index cd5345d01..dcf2e3369 100644 --- a/test/message/conftest.py +++ b/test/message/conftest.py @@ -21,6 +21,7 @@ def __init__(self, """Initialize a message instance.""" super().__init__(message_type, params, is_server, device_ids) self.send = mock.Mock() + self.message_type = message_type def callback_new_connection(self) -> ModbusProtocol: """Call when listener receive new connection request.""" diff --git a/test/message/test_ascii.py b/test/message/test_ascii.py index 1235e646e..e166828c3 100644 --- a/test/message/test_ascii.py +++ b/test/message/test_ascii.py @@ -11,7 +11,7 @@ class TestMessageAscii: @pytest.fixture(name="frame") def prepare_frame(): """Return message object.""" - return MessageAscii([1], False) + return MessageAscii() @pytest.mark.parametrize( diff --git a/test/message/test_message.py b/test/message/test_message.py index 22542ece1..113ec8f8c 100644 --- a/test/message/test_message.py +++ b/test/message/test_message.py @@ -74,7 +74,7 @@ async def test_message_build_send(self, msg): ]) async def test_validate_id(self, msg, dev_id, res): """Test message type.""" - assert res == msg.msg_handle.validate_device_id(dev_id) + assert res == msg.validate_device_id(dev_id) @pytest.mark.parametrize( ("data", "res_len", "res_id", "res_tid", "res_data"), [ @@ -159,7 +159,7 @@ class TestMessages: b'\x11\x03\x00\x7c\x00\x02\x07\x43', b'\x11\x03\x04\x00\x8d\x00\x8e\xfb\xbd', b'\x11\x83\x02\xc1\x34', - b'\xff\x03\x00|\x00\x02\x10\x0d', + b'\xff\x03\x00\x7c\x00\x02\x10\x0d', b'\xff\x03\x04\x00\x8d\x00\x8e\xf5\xb3', b'\xff\x83\x02\xa1\x01', ]), @@ -215,11 +215,10 @@ class TestMessages: ) def test_encode(self, frame, frame_expected, data, dev_id, tid, inx1, inx2, inx3): """Test encode method.""" - if frame != MessageSocket and tid: - pytest.skip("Not supported") - if frame == MessageTLS and (tid or dev_id): - pytest.skip("Not supported") - frame_obj = frame([0], True) + if ((frame != MessageSocket and tid) or + (frame == MessageTLS and dev_id)): + return + frame_obj = frame() expected = frame_expected[inx1 + inx2 + inx3] encoded_data = frame_obj.encode(data, dev_id, tid) assert encoded_data == expected @@ -281,7 +280,7 @@ async def test_decode(self, dummy_message, msg_type, data, dev_id, tid, expected if msg_type == MessageType.RTU: pytest.skip("Waiting on implementation!") if msg_type == MessageType.TLS and split != "no": - pytest.skip("Not supported.") + return frame = dummy_message( msg_type, CommParams(), @@ -323,7 +322,7 @@ async def test_decode_bad_crc(self, frame, data, exp_len): """Test encode method.""" if frame == MessageRTU: pytest.skip("Waiting for implementation.") - frame_obj = frame([0], True) + frame_obj = frame() used_len, _, _, data = frame_obj.decode(data) assert used_len == exp_len assert not data diff --git a/test/message/test_rtu.py b/test/message/test_rtu.py index 6892ab368..e7c019b99 100644 --- a/test/message/test_rtu.py +++ b/test/message/test_rtu.py @@ -11,7 +11,7 @@ class TestMessageRTU: @pytest.fixture(name="frame") def prepare_frame(): """Return message object.""" - return MessageRTU([1], False) + return MessageRTU() @pytest.mark.parametrize( diff --git a/test/message/test_socket.py b/test/message/test_socket.py index 1c4e930ce..e58ba65cc 100644 --- a/test/message/test_socket.py +++ b/test/message/test_socket.py @@ -12,7 +12,7 @@ class TestMessageSocket: @pytest.fixture(name="frame") def prepare_frame(): """Return message object.""" - return MessageSocket([1], False) + return MessageSocket() @pytest.mark.parametrize( diff --git a/test/message/test_tls.py b/test/message/test_tls.py index d9140c2cb..194fda459 100644 --- a/test/message/test_tls.py +++ b/test/message/test_tls.py @@ -12,7 +12,7 @@ class TestMessageSocket: @pytest.fixture(name="frame") def prepare_frame(): """Return message object.""" - return MessageTLS([1], False) + return MessageTLS() @pytest.mark.parametrize(