From 39177d72f130cb4d72138a0af41d677497cdb217 Mon Sep 17 00:00:00 2001
From: jan iversen <jancasacondor@gmail.com>
Date: Tue, 10 Oct 2023 09:59:10 +0200
Subject: [PATCH] Reduce transport_serial (#1807)

---
 pymodbus/transport/transport_serial.py | 219 +++++++++----------------
 test/sub_transport/test_basic.py       |   1 +
 2 files changed, 83 insertions(+), 137 deletions(-)

diff --git a/pymodbus/transport/transport_serial.py b/pymodbus/transport/transport_serial.py
index f5b28d5e3..09f4a7663 100644
--- a/pymodbus/transport/transport_serial.py
+++ b/pymodbus/transport/transport_serial.py
@@ -18,18 +18,62 @@ def __init__(self, loop, protocol, *args, **kwargs):
         self.async_loop = loop
         self._protocol: asyncio.BaseProtocol = protocol
         self.sync_serial = serial.serial_for_url(*args, **kwargs)
-        self._closing = False
         self._write_buffer = []
-        self.set_write_buffer_limits()
         self._has_reader = False
         self._has_writer = False
         self._poll_wait_time = 0.0005
-
-        # Asynchronous I/O requires non-blocking devices
         self.sync_serial.timeout = 0
         self.sync_serial.write_timeout = 0
-        loop.call_soon(protocol.connection_made, self)
-        loop.call_soon(self._ensure_reader)
+
+    def setup(self):
+        """Prepare to read/write"""
+        self.async_loop.call_soon(self._protocol.connection_made, self)
+        if os.name == "nt":
+            self._has_reader = self.async_loop.call_later(
+                self._poll_wait_time, self._poll_read
+            )
+        else:
+            self.async_loop.add_reader(self.sync_serial.fileno(), self._read_ready)
+            self._has_reader = True
+
+    def close(self, exc=None):
+        """Close the transport gracefully."""
+        if not self.sync_serial:
+            return
+        with contextlib.suppress(Exception):
+            self.sync_serial.flush()
+
+        if self._has_reader:
+            if os.name == "nt":
+                self._has_reader.cancel()
+            else:
+                self.async_loop.remove_reader(self.sync_serial.fileno())
+            self._has_reader = False
+        self.flush()
+        self.sync_serial.close()
+        self.sync_serial = None
+        with contextlib.suppress(Exception):
+            self._protocol.connection_lost(exc)
+
+    def write(self, data):
+        """Write some data to the transport."""
+        self._write_buffer.append(data)
+        if not self._has_writer:
+            if os.name == "nt":
+                self._has_writer = self.async_loop.call_soon(self._poll_write)
+            else:
+                self.async_loop.add_writer(self.sync_serial.fileno(), self._write_ready)
+                self._has_writer = True
+
+    def flush(self):
+        """Clear output buffer and stops any more data being written"""
+        if self._has_writer:
+            if os.name == "nt":
+                self._has_writer.cancel()
+            else:
+                self.async_loop.remove_writer(self.sync_serial.fileno())
+            self._has_writer = False
+        self._write_buffer.clear()
 
     # ------------------------------------------------
     # Dummy methods needed to please asyncio.Transport.
@@ -75,160 +119,61 @@ def pause_reading(self):
     def resume_reading(self):
         """Resume receiver."""
 
-    # ------------------------------------------------
-
     def is_closing(self):
         """Return True if the transport is closing or closed."""
-        return self._closing
+        return False
 
-    def close(self):
-        """Close the transport gracefully."""
-        if self._closing:
-            return
-        self._closing = True
-        self._remove_reader()
-        self._remove_writer()
-        self.async_loop.call_soon(self._call_connection_lost, None)
+    def abort(self):
+        """Close the transport immediately."""
+        self.close()
+
+    # ------------------------------------------------
 
     def _read_ready(self):
         """Test if there are data waiting."""
         try:
-            data = self.sync_serial.read(1024)
-        except serial.SerialException as exc:
-            self.async_loop.call_soon(self._call_connection_lost, exc)
-            self.close()
-        else:
-            if data:
+            if data := self.sync_serial.read(1024):
                 self._protocol.data_received(data)
-
-    def write(self, data):
-        """Write some data to the transport."""
-        if self._closing:
-            return
-
-        self._write_buffer.append(data)
-        self._ensure_writer()
-
-    def abort(self):
-        """Close the transport immediately."""
-        self.close()
-
-    def flush(self):
-        """Clear output buffer and stops any more data being written"""
-        self._remove_writer()
-        self._write_buffer.clear()
+        except serial.SerialException as exc:
+            self.close(exc=exc)
 
     def _write_ready(self):
         """Asynchronously write buffered data."""
         data = b"".join(self._write_buffer)
-        assert data, "Write buffer should not be empty"
-
-        self._write_buffer.clear()
-
         try:
-            nlen = self.sync_serial.write(data)
+            if nlen := self.sync_serial.write(data) < len(data):
+                self._write_buffer = data[nlen:]
+                return True
+            self.flush()
         except (BlockingIOError, InterruptedError):
-            self._write_buffer.append(data)
+            return True
         except serial.SerialException as exc:
-            self.async_loop.call_soon(self._call_connection_lost, exc)
-            self.abort()
-        else:
-            if nlen == len(data):
-                assert not self.get_write_buffer_size()
-                self._remove_writer()
-                if self._closing and not self.get_write_buffer_size():
-                    self.close()
-                return
-
-            assert 0 <= nlen < len(data)
-            data = data[nlen:]
-            self._write_buffer.append(data)  # Try again later
-            assert self._has_writer
-
-    if os.name == "nt":
-
-        def _poll_read(self):
-            if self._has_reader and not self._closing:
-                try:
-                    self._has_reader = self.async_loop.call_later(
-                        self._poll_wait_time, self._poll_read
-                    )
-                    if self.sync_serial.in_waiting:
-                        self._read_ready()
-                except serial.SerialException as exc:
-                    self.async_loop.call_soon(self._call_connection_lost, exc)
-                    self.abort()
-
-        def _ensure_reader(self):
-            if not self._has_reader and not self._closing:
+            self.close(exc=exc)
+        return False
+
+    def _poll_read(self):
+        if self._has_reader:
+            try:
                 self._has_reader = self.async_loop.call_later(
                     self._poll_wait_time, self._poll_read
                 )
+                if self.sync_serial.in_waiting:
+                    self._read_ready()
+            except serial.SerialException as exc:
+                self.close(exc=exc)
 
-        def _remove_reader(self):
-            if self._has_reader:
-                self._has_reader.cancel()
-            self._has_reader = False
-
-        def _poll_write(self):
-            if self._has_writer and not self._closing:
-                self._has_writer = self.async_loop.call_later(
-                    self._poll_wait_time, self._poll_write
-                )
-                self._write_ready()
-
-        def _ensure_writer(self):
-            if not self._has_writer and not self._closing:
-                self._has_writer = self.async_loop.call_soon(self._poll_write)
-
-        def _remove_writer(self):
-            if self._has_writer:
-                self._has_writer.cancel()
-            self._has_writer = False
-
-    else:
-
-        def _ensure_reader(self):
-            if (not self._has_reader) and (not self._closing):
-                self.async_loop.add_reader(self.sync_serial.fileno(), self._read_ready)
-                self._has_reader = True
-
-        def _remove_reader(self):
-            if self._has_reader:
-                self.async_loop.remove_reader(self.sync_serial.fileno())
-                self._has_reader = False
-
-        def _ensure_writer(self):
-            if (not self._has_writer) and (not self._closing):
-                self.async_loop.add_writer(self.sync_serial.fileno(), self._write_ready)
-                self._has_writer = True
-
-        def _remove_writer(self):
-            if self._has_writer:
-                self.async_loop.remove_writer(self.sync_serial.fileno())
-                self._has_writer = False
-
-    def _call_connection_lost(self, exc):
-        """Close the connection."""
-        assert self._closing
-        assert not self._has_writer
-        assert not self._has_reader
-        if self.sync_serial:
-            with contextlib.suppress(Exception):
-                self.sync_serial.flush()
-
-            self.sync_serial.close()
-            self.sync_serial = None
-        if self._protocol:
-            with contextlib.suppress(Exception):
-                self._protocol.connection_lost(exc)
-
-            self._write_buffer.clear()
-        self._write_buffer.clear()
+    def _poll_write(self):
+        if not self._has_writer:
+            return
+        if self._write_ready():
+            self._has_writer = self.async_loop.call_later(
+                self._poll_wait_time, self._poll_write
+            )
 
 
 async def create_serial_connection(loop, protocol_factory, *args, **kwargs):
     """Create a connection to a new serial port instance."""
     protocol = protocol_factory()
     transport = SerialTransport(loop, protocol, *args, **kwargs)
+    loop.call_soon(transport.setup)
     return transport, protocol
diff --git a/test/sub_transport/test_basic.py b/test/sub_transport/test_basic.py
index 788818ed0..de30abacc 100644
--- a/test/sub_transport/test_basic.py
+++ b/test/sub_transport/test_basic.py
@@ -323,6 +323,7 @@ async def test_external_methods(self):
         comm.write(b"abcd")
         comm.flush()
         comm.close()
+        comm = SerialTransport(mock.MagicMock(), mock.Mock(), "dummy")
         comm.abort()
         assert await create_serial_connection(
             asyncio.get_running_loop(), mock.Mock, url="dummy"