Skip to content

Commit

Permalink
New nullmodem and transport.
Browse files Browse the repository at this point in the history
  • Loading branch information
janiversen committed Jul 28, 2023
1 parent b3e63f0 commit 30104dc
Show file tree
Hide file tree
Showing 12 changed files with 629 additions and 227 deletions.
13 changes: 13 additions & 0 deletions pymodbus/transport/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,14 @@
"""Transport."""
__all__ = [
"CommType",
"create_serial_connection",
"ModbusProtocol",
"NullModem",
"SerialTransport",
]

from pymodbus.transport.transport import CommType, ModbusProtocol, NullModem
from pymodbus.transport.transport_serial import (
SerialTransport,
create_serial_connection,
)
157 changes: 112 additions & 45 deletions pymodbus/transport/transport.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""ModbusProtocol layer."""
# mypy: disable-error-code="name-defined,union-attr"
# needed because asyncio.Server is not defined (to mypy) in v3.8.16
from __future__ import annotations

import asyncio
import dataclasses
import ssl
from contextlib import suppress
from enum import Enum
from typing import Any, Callable, Coroutine

Expand Down Expand Up @@ -124,7 +123,7 @@ def __init__(
self.is_server = is_server
self.is_closing = False

self.transport: asyncio.BaseModbusProtocol | asyncio.Server = None
self.transport: asyncio.BaseTransport = None
self.loop: asyncio.AbstractEventLoop = None
self.recv_buffer: bytes = b""
self.call_create: Callable[[], Coroutine[Any, Any, Any]] = lambda: None
Expand Down Expand Up @@ -258,7 +257,7 @@ async def transport_listen(self) -> bool:
# ---------------------------------- #
# ModbusProtocol asyncio standard methods #
# ---------------------------------- #
def connection_made(self, transport: asyncio.BaseModbusProtocol):
def connection_made(self, transport: asyncio.BaseTransport):
"""Call from asyncio, when a connection is made.
:param transport: socket etc. representing the connection.
Expand Down Expand Up @@ -298,10 +297,23 @@ def datagram_received(self, data: bytes, addr: tuple):
self.sent_buffer = b""
if not data:
return
Log.debug("recv: {} addr={}", data, ":hex", addr)
Log.debug(
"recv: {} old_data: {} addr={}",
data,
":hex",
self.recv_buffer,
":hex",
addr,
)
self.recv_buffer += data
cut = self.callback_data(self.recv_buffer, addr=addr)
self.recv_buffer = self.recv_buffer[cut:]
if self.recv_buffer:
Log.debug(
"recv, unused data waiting for next packet: {}",
self.recv_buffer,
":hex",
)

def eof_received(self):
"""Accept other end terminates connection."""
Expand Down Expand Up @@ -342,11 +354,11 @@ def transport_send(self, data: bytes, addr: tuple = None) -> None:
self.sent_buffer = data
if self.comm_params.comm_type == CommType.UDP:
if addr:
self.transport.sendto(data, addr=addr)
self.transport.sendto(data, addr=addr) # type: ignore[attr-defined]
else:
self.transport.sendto(data)
self.transport.sendto(data) # type: ignore[attr-defined]
else:
self.transport.write(data)
self.transport.write(data) # type: ignore[attr-defined]

def transport_close(self, intern: bool = False, reconnect: bool = False) -> None:
"""Close connection.
Expand Down Expand Up @@ -392,26 +404,11 @@ async def create_nullmodem(self, port):
"""Bypass create_ and use null modem"""
if self.is_server:
# Listener object
self.transport = NullModem(self)
NullModem.listener_new_connection[port] = self.handle_new_connection
self.transport = NullModem.set_listener(port, self)
return self.transport, self

# connect object
client_protocol = self.handle_new_connection()
try:
server_protocol = NullModem.listener_new_connection[port]()
except KeyError as exc:
raise asyncio.TimeoutError(
f"No listener on port {self.comm_params.port} for connect"
) from exc

client_transport = NullModem(client_protocol)
server_transport = NullModem(server_protocol)
client_transport.other_transport = server_transport
server_transport.other_transport = client_transport
client_protocol.connection_made(client_transport)
server_protocol.connection_made(server_transport)
return client_transport, client_protocol
return NullModem.set_connection(port, self)

def handle_new_connection(self):
"""Handle incoming connect."""
Expand Down Expand Up @@ -468,46 +465,117 @@ class NullModem(asyncio.DatagramTransport, asyncio.Transport):
(Allowing tests to be shortcut without actual network calls)
"""

listener_new_connection: dict[int, ModbusProtocol] = {}
listeners: dict[int, ModbusProtocol] = {}
connections: dict[NullModem, int] = {}

def __init__(self, protocol: ModbusProtocol):
def __init__(self, protocol: ModbusProtocol, listen: int = None) -> None:
"""Create half part of null modem"""
asyncio.DatagramTransport.__init__(self)
asyncio.Transport.__init__(self)
self.other: NullModem = None
self.protocol: ModbusProtocol | asyncio.BaseProtocol = protocol
self.protocol: ModbusProtocol = protocol
self.serving: asyncio.Future = asyncio.Future()
self.other_transport: NullModem = None
self.other_modem: NullModem = None
self.listen = listen
self.manipulator: Callable[[bytes], list[bytes]] = None
self._is_closing = False

# -------------------------- #
# external nullmodem methods #
# -------------------------- #
@classmethod
def set_listener(cls, port: int, parent: ModbusProtocol) -> NullModem:
"""Register listener."""
if port in cls.listeners:
raise AssertionError(f"Port {port} already listening !")
cls.listeners[port] = parent
return NullModem(parent, listen=port)

@classmethod
def set_connection(
cls, port: int, parent: ModbusProtocol
) -> tuple[NullModem, ModbusProtocol]:
"""Connect to listener."""
if port not in cls.listeners:
raise asyncio.TimeoutError(f"Port {port} not being listened on !")

client_protocol = parent.handle_new_connection()
server_protocol = NullModem.listeners[port].handle_new_connection()
client_transport = NullModem(client_protocol)
server_transport = NullModem(server_protocol)
cls.connections[client_transport] = port
cls.connections[server_transport] = -port
client_transport.other_modem = server_transport
server_transport.other_modem = client_transport
client_protocol.connection_made(client_transport)
server_protocol.connection_made(server_transport)
return client_transport, client_protocol

def set_manipulator(self, function: Callable[[bytes], list[bytes]]) -> None:
"""Register a manipulator."""
self.manipulator = function

@classmethod
def is_dirty(cls):
"""Check if everything is closed."""
dirty = False
if cls.connections:
Log.error(
"NullModem_FATAL missing close on port {} connect()",
[str(key) for key in cls.connections.values()],
)
dirty = True
if cls.listeners:
Log.error(
"NullModem_FATAL missing close on port {} listen()",
[str(value) for value in cls.listeners],
)
dirty = True
return dirty

# ---------------- #
# external methods #
# ---------------- #

def close(self):
def close(self) -> None:
"""Close null modem"""
if self._is_closing:
return
self._is_closing = True
if not self.serving.done():
self.serving.set_result(True)
if self.other_transport:
self.other_transport.other_transport = None
self.other_transport.protocol.connection_lost(None)
self.other_transport = None
if self.listen:
del self.listeners[self.listen]
return
if self.connections:
with suppress(KeyError):
del self.connections[self]
if self.other_modem:
self.other_modem.other_modem = None
self.other_modem.close()
self.other_modem = None
if self.protocol:
self.protocol.connection_lost(None)

def sendto(self, data: bytes, _addr: Any = None):
def sendto(self, data: bytes, _addr: Any = None) -> None:
"""Send datagrame"""
return self.write(data)
self.write(data)

def write(self, data: bytes):
def write(self, data: bytes) -> None:
"""Send data"""
self.other_transport.protocol.data_received(data)
if not self.manipulator:
self.other_modem.protocol.data_received(data)
return
data_manipulated = self.manipulator(data)
for part in data_manipulated:
self.other_modem.protocol.data_received(part)

async def serve_forever(self):
async def serve_forever(self) -> None:
"""Serve forever"""
await self.serving

# ---------------- #
# Abstract methods #
# ---------------- #
# ------------- #
# Dummy methods #
# ------------- #
def abort(self) -> None:
"""Abort connection."""
self.close()
Expand Down Expand Up @@ -536,11 +604,10 @@ def get_protocol(self) -> ModbusProtocol | asyncio.BaseProtocol:

def set_protocol(self, protocol: asyncio.BaseProtocol) -> None:
"""Set current protocol."""
self.protocol = protocol

def is_closing(self) -> bool:
"""Return true if closing"""
return False
return self._is_closing

def is_reading(self) -> bool:
"""Return true if read is active."""
Expand Down
16 changes: 16 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,22 @@ def pytest_configure():
# -----------------------------------------------------------------------#
# Generic fixtures
# -----------------------------------------------------------------------#
BASE_PORTS = {
"TestBasicModbusProtocol": 8100,
"TestBasicSerial": 8200,
"TestCommModbusProtocol": 8300,
"TestCommNullModem": 8400,
"TestExamples": 8500,
"TestModbusProtocol": 8600,
"TestNullModem": 8700,
"TestReconnectModbusProtocol": 8800,
}


@pytest.fixture(name="base_port", scope="package")
def get_base_ports():
"""Return base_ports"""
return BASE_PORTS


class MockContext(ModbusBaseSlaveContext):
Expand Down
Loading

0 comments on commit 30104dc

Please sign in to comment.