Skip to content

Commit

Permalink
Reduce transport_serial (#1807)
Browse files Browse the repository at this point in the history
  • Loading branch information
janiversen authored Oct 10, 2023
1 parent f06718d commit 39177d7
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 137 deletions.
219 changes: 82 additions & 137 deletions pymodbus/transport/transport_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,62 @@ def __init__(self, loop, protocol, *args, **kwargs):
self.async_loop = loop
self._protocol: asyncio.BaseProtocol = protocol
self.sync_serial = serial.serial_for_url(*args, **kwargs)
self._closing = False
self._write_buffer = []
self.set_write_buffer_limits()
self._has_reader = False
self._has_writer = False
self._poll_wait_time = 0.0005

# Asynchronous I/O requires non-blocking devices
self.sync_serial.timeout = 0
self.sync_serial.write_timeout = 0
loop.call_soon(protocol.connection_made, self)
loop.call_soon(self._ensure_reader)

def setup(self):
"""Prepare to read/write"""
self.async_loop.call_soon(self._protocol.connection_made, self)
if os.name == "nt":
self._has_reader = self.async_loop.call_later(
self._poll_wait_time, self._poll_read
)
else:
self.async_loop.add_reader(self.sync_serial.fileno(), self._read_ready)
self._has_reader = True

def close(self, exc=None):
"""Close the transport gracefully."""
if not self.sync_serial:
return
with contextlib.suppress(Exception):
self.sync_serial.flush()

if self._has_reader:
if os.name == "nt":
self._has_reader.cancel()
else:
self.async_loop.remove_reader(self.sync_serial.fileno())
self._has_reader = False
self.flush()
self.sync_serial.close()
self.sync_serial = None
with contextlib.suppress(Exception):
self._protocol.connection_lost(exc)

def write(self, data):
"""Write some data to the transport."""
self._write_buffer.append(data)
if not self._has_writer:
if os.name == "nt":
self._has_writer = self.async_loop.call_soon(self._poll_write)
else:
self.async_loop.add_writer(self.sync_serial.fileno(), self._write_ready)
self._has_writer = True

def flush(self):
"""Clear output buffer and stops any more data being written"""
if self._has_writer:
if os.name == "nt":
self._has_writer.cancel()
else:
self.async_loop.remove_writer(self.sync_serial.fileno())
self._has_writer = False
self._write_buffer.clear()

# ------------------------------------------------
# Dummy methods needed to please asyncio.Transport.
Expand Down Expand Up @@ -75,160 +119,61 @@ def pause_reading(self):
def resume_reading(self):
"""Resume receiver."""

# ------------------------------------------------

def is_closing(self):
"""Return True if the transport is closing or closed."""
return self._closing
return False

def close(self):
"""Close the transport gracefully."""
if self._closing:
return
self._closing = True
self._remove_reader()
self._remove_writer()
self.async_loop.call_soon(self._call_connection_lost, None)
def abort(self):
"""Close the transport immediately."""
self.close()

# ------------------------------------------------

def _read_ready(self):
"""Test if there are data waiting."""
try:
data = self.sync_serial.read(1024)
except serial.SerialException as exc:
self.async_loop.call_soon(self._call_connection_lost, exc)
self.close()
else:
if data:
if data := self.sync_serial.read(1024):
self._protocol.data_received(data)

def write(self, data):
"""Write some data to the transport."""
if self._closing:
return

self._write_buffer.append(data)
self._ensure_writer()

def abort(self):
"""Close the transport immediately."""
self.close()

def flush(self):
"""Clear output buffer and stops any more data being written"""
self._remove_writer()
self._write_buffer.clear()
except serial.SerialException as exc:
self.close(exc=exc)

def _write_ready(self):
"""Asynchronously write buffered data."""
data = b"".join(self._write_buffer)
assert data, "Write buffer should not be empty"

self._write_buffer.clear()

try:
nlen = self.sync_serial.write(data)
if nlen := self.sync_serial.write(data) < len(data):
self._write_buffer = data[nlen:]
return True
self.flush()
except (BlockingIOError, InterruptedError):
self._write_buffer.append(data)
return True
except serial.SerialException as exc:
self.async_loop.call_soon(self._call_connection_lost, exc)
self.abort()
else:
if nlen == len(data):
assert not self.get_write_buffer_size()
self._remove_writer()
if self._closing and not self.get_write_buffer_size():
self.close()
return

assert 0 <= nlen < len(data)
data = data[nlen:]
self._write_buffer.append(data) # Try again later
assert self._has_writer

if os.name == "nt":

def _poll_read(self):
if self._has_reader and not self._closing:
try:
self._has_reader = self.async_loop.call_later(
self._poll_wait_time, self._poll_read
)
if self.sync_serial.in_waiting:
self._read_ready()
except serial.SerialException as exc:
self.async_loop.call_soon(self._call_connection_lost, exc)
self.abort()

def _ensure_reader(self):
if not self._has_reader and not self._closing:
self.close(exc=exc)
return False

def _poll_read(self):
if self._has_reader:
try:
self._has_reader = self.async_loop.call_later(
self._poll_wait_time, self._poll_read
)
if self.sync_serial.in_waiting:
self._read_ready()
except serial.SerialException as exc:
self.close(exc=exc)

def _remove_reader(self):
if self._has_reader:
self._has_reader.cancel()
self._has_reader = False

def _poll_write(self):
if self._has_writer and not self._closing:
self._has_writer = self.async_loop.call_later(
self._poll_wait_time, self._poll_write
)
self._write_ready()

def _ensure_writer(self):
if not self._has_writer and not self._closing:
self._has_writer = self.async_loop.call_soon(self._poll_write)

def _remove_writer(self):
if self._has_writer:
self._has_writer.cancel()
self._has_writer = False

else:

def _ensure_reader(self):
if (not self._has_reader) and (not self._closing):
self.async_loop.add_reader(self.sync_serial.fileno(), self._read_ready)
self._has_reader = True

def _remove_reader(self):
if self._has_reader:
self.async_loop.remove_reader(self.sync_serial.fileno())
self._has_reader = False

def _ensure_writer(self):
if (not self._has_writer) and (not self._closing):
self.async_loop.add_writer(self.sync_serial.fileno(), self._write_ready)
self._has_writer = True

def _remove_writer(self):
if self._has_writer:
self.async_loop.remove_writer(self.sync_serial.fileno())
self._has_writer = False

def _call_connection_lost(self, exc):
"""Close the connection."""
assert self._closing
assert not self._has_writer
assert not self._has_reader
if self.sync_serial:
with contextlib.suppress(Exception):
self.sync_serial.flush()

self.sync_serial.close()
self.sync_serial = None
if self._protocol:
with contextlib.suppress(Exception):
self._protocol.connection_lost(exc)

self._write_buffer.clear()
self._write_buffer.clear()
def _poll_write(self):
if not self._has_writer:
return
if self._write_ready():
self._has_writer = self.async_loop.call_later(
self._poll_wait_time, self._poll_write
)


async def create_serial_connection(loop, protocol_factory, *args, **kwargs):
"""Create a connection to a new serial port instance."""
protocol = protocol_factory()
transport = SerialTransport(loop, protocol, *args, **kwargs)
loop.call_soon(transport.setup)
return transport, protocol
1 change: 1 addition & 0 deletions test/sub_transport/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ async def test_external_methods(self):
comm.write(b"abcd")
comm.flush()
comm.close()
comm = SerialTransport(mock.MagicMock(), mock.Mock(), "dummy")
comm.abort()
assert await create_serial_connection(
asyncio.get_running_loop(), mock.Mock, url="dummy"
Expand Down

0 comments on commit 39177d7

Please sign in to comment.