From 1d1750ff6abf1f6fd2c72bf325a986f4d3d72256 Mon Sep 17 00:00:00 2001 From: jan iversen Date: Wed, 11 Oct 2023 09:32:36 +0200 Subject: [PATCH] Simplify transport_serial (modbus use) (#1808) --- pymodbus/transport/transport_serial.py | 82 +++++++++++--------------- test/sub_transport/test_basic.py | 6 +- test/sub_transport/test_comm.py | 56 ++++++++++++++++++ 3 files changed, 96 insertions(+), 48 deletions(-) diff --git a/pymodbus/transport/transport_serial.py b/pymodbus/transport/transport_serial.py index 09f4a7663..f57f12df4 100644 --- a/pymodbus/transport/transport_serial.py +++ b/pymodbus/transport/transport_serial.py @@ -12,6 +12,8 @@ class SerialTransport(asyncio.Transport): """An asyncio serial transport.""" + force_poll: bool = False + def __init__(self, loop, protocol, *args, **kwargs): """Initialize.""" super().__init__() @@ -19,22 +21,18 @@ def __init__(self, loop, protocol, *args, **kwargs): self._protocol: asyncio.BaseProtocol = protocol self.sync_serial = serial.serial_for_url(*args, **kwargs) self._write_buffer = [] - self._has_reader = False - self._has_writer = False + self.poll_task = None self._poll_wait_time = 0.0005 self.sync_serial.timeout = 0 self.sync_serial.write_timeout = 0 def setup(self): """Prepare to read/write""" - self.async_loop.call_soon(self._protocol.connection_made, self) - if os.name == "nt": - self._has_reader = self.async_loop.call_later( - self._poll_wait_time, self._poll_read - ) + if os.name == "nt" or self.force_poll: + self.poll_task = asyncio.create_task(self._polling_task()) else: self.async_loop.add_reader(self.sync_serial.fileno(), self._read_ready) - self._has_reader = True + self.async_loop.call_soon(self._protocol.connection_made, self) def close(self, exc=None): """Close the transport gracefully.""" @@ -43,13 +41,13 @@ def close(self, exc=None): with contextlib.suppress(Exception): self.sync_serial.flush() - if self._has_reader: - if os.name == "nt": - self._has_reader.cancel() - else: - self.async_loop.remove_reader(self.sync_serial.fileno()) - self._has_reader = False self.flush() + if self.poll_task: + self.poll_task.cancel() + _ = asyncio.ensure_future(self.poll_task) + self.poll_task = None + else: + self.async_loop.remove_reader(self.sync_serial.fileno()) self.sync_serial.close() self.sync_serial = None with contextlib.suppress(Exception): @@ -58,21 +56,13 @@ def close(self, exc=None): def write(self, data): """Write some data to the transport.""" self._write_buffer.append(data) - if not self._has_writer: - if os.name == "nt": - self._has_writer = self.async_loop.call_soon(self._poll_write) - else: - self.async_loop.add_writer(self.sync_serial.fileno(), self._write_ready) - self._has_writer = True + if not self.poll_task: + self.async_loop.add_writer(self.sync_serial.fileno(), self._write_ready) def flush(self): """Clear output buffer and stops any more data being written""" - if self._has_writer: - if os.name == "nt": - self._has_writer.cancel() - else: - self.async_loop.remove_writer(self.sync_serial.fileno()) - self._has_writer = False + if not self.poll_task: + self.async_loop.remove_writer(self.sync_serial.fileno()) self._write_buffer.clear() # ------------------------------------------------ @@ -141,34 +131,32 @@ def _write_ready(self): """Asynchronously write buffered data.""" data = b"".join(self._write_buffer) try: - if nlen := self.sync_serial.write(data) < len(data): - self._write_buffer = data[nlen:] - return True + if (nlen := self.sync_serial.write(data)) < len(data): + self._write_buffer = [data[nlen:]] + if not self.poll_task: + self.async_loop.add_writer( + self.sync_serial.fileno(), self._write_ready + ) + return self.flush() except (BlockingIOError, InterruptedError): - return True + return except serial.SerialException as exc: self.close(exc=exc) - return False - def _poll_read(self): - if self._has_reader: - try: - self._has_reader = self.async_loop.call_later( - self._poll_wait_time, self._poll_read - ) + async def _polling_task(self): + """Poll and try to read/write.""" + try: + while True: + await asyncio.sleep(self._poll_wait_time) + while self._write_buffer: + self._write_ready() if self.sync_serial.in_waiting: self._read_ready() - except serial.SerialException as exc: - self.close(exc=exc) - - def _poll_write(self): - if not self._has_writer: - return - if self._write_ready(): - self._has_writer = self.async_loop.call_later( - self._poll_wait_time, self._poll_write - ) + except serial.SerialException as exc: + self.close(exc=exc) + except asyncio.CancelledError: + pass async def create_serial_connection(loop, protocol_factory, *args, **kwargs): diff --git a/test/sub_transport/test_basic.py b/test/sub_transport/test_basic.py index de30abacc..9eda7b6b7 100644 --- a/test/sub_transport/test_basic.py +++ b/test/sub_transport/test_basic.py @@ -325,6 +325,10 @@ async def test_external_methods(self): comm.close() comm = SerialTransport(mock.MagicMock(), mock.Mock(), "dummy") comm.abort() - assert await create_serial_connection( + transport, protocol = await create_serial_connection( asyncio.get_running_loop(), mock.Mock, url="dummy" ) + await asyncio.sleep(0.1) + assert transport + assert protocol + transport.close() diff --git a/test/sub_transport/test_comm.py b/test/sub_transport/test_comm.py index 87cb964a3..22f67c1ba 100644 --- a/test/sub_transport/test_comm.py +++ b/test/sub_transport/test_comm.py @@ -9,6 +9,7 @@ CommType, ModbusProtocol, ) +from pymodbus.transport.transport_serial import SerialTransport FACTOR = 1.2 if not pytest.IS_WINDOWS else 4.2 @@ -125,6 +126,61 @@ async def test_connected(self, client, server, use_comm_type): assert not server.active_connections server.transport_close() + def wrapped_write(self, data): + """Wrap serial write, to split parameters.""" + return self.serial_write(data[:2]) + + @pytest.mark.parametrize( + ("use_comm_type", "use_host"), + [ + (CommType.SERIAL, "socket://localhost:5020"), + ], + ) + async def test_split_serial_packet(self, client, server): + """Test connection and data exchange.""" + assert await server.transport_listen() + assert await client.transport_connect() + await asyncio.sleep(0.5) + assert len(server.active_connections) == 1 + server_connected = list(server.active_connections.values())[0] + test_data = b"abcd" + + self.serial_write = ( # pylint: disable=attribute-defined-outside-init + client.transport.sync_serial.write + ) + with mock.patch.object( + client.transport.sync_serial, "write", wraps=self.wrapped_write + ): + client.transport_send(test_data) + await asyncio.sleep(0.5) + assert server_connected.recv_buffer == test_data + assert not client.recv_buffer + client.transport_close() + server.transport_close() + + @pytest.mark.parametrize( + ("use_comm_type", "use_host"), + [ + (CommType.SERIAL, "socket://localhost:5020"), + ], + ) + async def test_serial_poll(self, client, server): + """Test connection and data exchange.""" + assert await server.transport_listen() + SerialTransport.force_poll = True + assert await client.transport_connect() + await asyncio.sleep(0.5) + SerialTransport.force_poll = False + assert len(server.active_connections) == 1 + server_connected = list(server.active_connections.values())[0] + test_data = b"abcd" * 1000 + client.transport_send(test_data) + await asyncio.sleep(0.5) + assert server_connected.recv_buffer == test_data + assert not client.recv_buffer + client.transport_close() + server.transport_close() + @pytest.mark.parametrize( ("use_comm_type", "use_host"), [