From a513f0a6a3515d1b054f6756b9b5d30cbb96cb0f Mon Sep 17 00:00:00 2001 From: jan iversen Date: Sun, 2 Jul 2023 21:49:34 +0200 Subject: [PATCH 1/4] Simplify clients. --- pymodbus/client/base.py | 48 ++++----- pymodbus/client/serial.py | 49 +++++---- pymodbus/client/tcp.py | 22 ++-- pymodbus/client/tls.py | 20 ++-- pymodbus/client/udp.py | 19 ++-- pymodbus/framer/rtu_framer.py | 4 +- pymodbus/repl/client/mclient.py | 28 ++--- pymodbus/server/async_io.py | 13 ++- pymodbus/transport/transport.py | 154 +++++++++++++++++---------- test/sub_transport/conftest.py | 41 +++++-- test/sub_transport/test_basic.py | 59 ++++------ test/sub_transport/test_comm.py | 10 +- test/sub_transport/test_reconnect.py | 16 +-- test/test_client_sync.py | 14 +-- test/test_framers.py | 2 +- 15 files changed, 272 insertions(+), 227 deletions(-) diff --git a/pymodbus/client/base.py b/pymodbus/client/base.py index dfd999be2..dbf0976e1 100644 --- a/pymodbus/client/base.py +++ b/pymodbus/client/base.py @@ -52,9 +52,6 @@ class ModbusBaseClient(ModbusClientMixin, ModbusProtocol): class _params: """Parameter class.""" - host: str = None - port: str | int = None - timeout: float = None retries: int = None retry_on_empty: bool = None close_comm_on_error: bool = None @@ -62,10 +59,6 @@ class _params: broadcast_enable: bool = None reconnect_delay: int = None - baudrate: int = None - bytesize: int = None - parity: str = None - stopbits: int = None handle_local_echo: bool = None source_address: tuple[str, int] = None @@ -89,28 +82,31 @@ def __init__( # pylint: disable=too-many-arguments """Initialize a client instance.""" ModbusClientMixin.__init__(self) self.use_sync = kwargs.get("use_sync", False) + setup_params = CommParams( + comm_type=kwargs.get("CommType"), + comm_name="comm", + source_address=kwargs.get("source_address", ("localhost", 0)), + reconnect_delay=reconnect_delay, + reconnect_delay_max=reconnect_delay_max, + timeout_connect=timeout, + host=kwargs.get("host", None), + port=kwargs.get("port", None), + sslctx=kwargs.get("sslctx", None), + baudrate=kwargs.get("baudrate", None), + bytesize=kwargs.get("bytesize", None), + parity=kwargs.get("parity", None), + stopbits=kwargs.get("stopbits", None), + ) if not self.use_sync: ModbusProtocol.__init__( self, - CommParams( - comm_type=kwargs.get("CommType"), - comm_name="comm", - reconnect_delay=reconnect_delay, - reconnect_delay_max=reconnect_delay_max, - timeout_connect=timeout, - host=kwargs.get("host", None), - port=kwargs.get("port", None), - sslctx=kwargs.get("sslctx", None), - baudrate=kwargs.get("baudrate", None), - bytesize=kwargs.get("bytesize", None), - parity=kwargs.get("parity", None), - stopbits=kwargs.get("stopbits", None), - ), + setup_params, False, ) + else: + self.comm_params = setup_params self.framer = framer self.params = self._params() - self.params.timeout = float(timeout) self.params.retries = int(retries) self.params.retry_on_empty = bool(retry_on_empty) self.params.close_comm_on_error = bool(close_comm_on_error) @@ -199,7 +195,9 @@ async def async_execute(self, request=None): resp = b"Broadcast write sent - no response expected" else: try: - resp = await asyncio.wait_for(req, timeout=self.params.timeout) + resp = await asyncio.wait_for( + req, timeout=self.comm_params.timeout_connect + ) except asyncio.exceptions.TimeoutError: self.close(reconnect=True) raise @@ -315,4 +313,6 @@ def __str__(self): :returns: The string representation """ - return f"{self.__class__.__name__} {self.params.host}:{self.params.port}" + return ( + f"{self.__class__.__name__} {self.comm_params.host}:{self.comm_params.port}" + ) diff --git a/pymodbus/client/serial.py b/pymodbus/client/serial.py index cb32131cf..50f7597da 100644 --- a/pymodbus/client/serial.py +++ b/pymodbus/client/serial.py @@ -68,11 +68,6 @@ def __init__( stopbits=stopbits, **kwargs, ) - self.params.port = port - self.params.baudrate = baudrate - self.params.bytesize = bytesize - self.params.parity = parity - self.params.stopbits = stopbits self.params.handle_local_echo = handle_local_echo @property @@ -138,18 +133,23 @@ def __init__( """Initialize Modbus Serial Client.""" self.transport = None kwargs["use_sync"] = True - super().__init__(framer=framer, **kwargs) - self.params.port = port - self.params.baudrate = baudrate - self.params.bytesize = bytesize - self.params.parity = parity - self.params.stopbits = stopbits + ModbusBaseClient.__init__( + self, + framer=framer, + CommType=CommType.SERIAL, + host=port, + baudrate=baudrate, + bytesize=bytesize, + parity=parity, + stopbits=stopbits, + **kwargs, + ) self.params.handle_local_echo = handle_local_echo self.socket = None self.last_frame_end = None - self._t0 = float(1 + 8 + 2) / self.params.baudrate + self._t0 = float(1 + 8 + 2) / self.comm_params.baudrate """ The minimum delay is 0.01s and the maximum can be set to 0.05s. @@ -161,7 +161,7 @@ def __init__( else 0.05 ) - if self.params.baudrate > 19200: + if self.comm_params.baudrate > 19200: self.silent_interval = 1.75 / 1000 # ms else: self.inter_char_timeout = 1.5 * self._t0 @@ -179,12 +179,12 @@ def connect(self): # pylint: disable=invalid-overridden-method return True try: self.socket = serial.serial_for_url( - self.params.port, - timeout=self.params.timeout, - bytesize=self.params.bytesize, - stopbits=self.params.stopbits, - baudrate=self.params.baudrate, - parity=self.params.parity, + self.comm_params.host, + timeout=self.comm_params.timeout_connect, + bytesize=self.comm_params.bytesize, + stopbits=self.comm_params.stopbits, + baudrate=self.comm_params.baudrate, + parity=self.comm_params.parity, ) if isinstance(self.framer, ModbusRtuFramer): if self.params.strict: @@ -244,10 +244,13 @@ def _wait_for_data(self): """Wait for data.""" size = 0 more_data = False - if self.params.timeout is not None and self.params.timeout: + if ( + self.comm_params.timeout_connect is not None + and self.comm_params.timeout_connect + ): condition = partial( lambda start, timeout: (time.time() - start) <= timeout, - timeout=self.params.timeout, + timeout=self.comm_params.timeout_connect, ) else: condition = partial(lambda dummy1, dummy2: True, dummy2=None) @@ -286,11 +289,11 @@ def is_socket_open(self): def __str__(self): """Build a string representation of the connection.""" - return f"ModbusSerialClient({self.framer} baud[{self.params.baudrate}])" + return f"ModbusSerialClient({self.framer} baud[{self.comm_params.baudrate}])" def __repr__(self): """Return string representation.""" return ( f"<{self.__class__.__name__} at {hex(id(self))} socket={self.socket}, " - f"framer={self.framer}, timeout={self.params.timeout}>" + f"framer={self.framer}, timeout={self.comm_params.timeout_connect}>" ) diff --git a/pymodbus/client/tcp.py b/pymodbus/client/tcp.py index 765201044..a99700734 100644 --- a/pymodbus/client/tcp.py +++ b/pymodbus/client/tcp.py @@ -54,8 +54,6 @@ def __init__( port=port, **kwargs, ) - self.params.host = host - self.params.port = port self.params.source_address = source_address async def connect(self) -> bool: @@ -111,11 +109,11 @@ def __init__( **kwargs: Any, ) -> None: """Initialize Modbus TCP Client.""" + if "CommType" not in kwargs: + kwargs["CommType"] = CommType.TCP kwargs["use_sync"] = True self.transport = None super().__init__(framer=framer, host=host, port=port, **kwargs) - self.params.host = host - self.params.port = port self.params.source_address = source_address self.socket = None @@ -130,8 +128,8 @@ def connect(self): # pylint: disable=invalid-overridden-method return True try: self.socket = socket.create_connection( - (self.params.host, self.params.port), - timeout=self.params.timeout, + (self.comm_params.host, self.comm_params.port), + timeout=self.comm_params.timeout_connect, source_address=self.params.source_address, ) Log.debug( @@ -141,8 +139,8 @@ def connect(self): # pylint: disable=invalid-overridden-method except OSError as msg: Log.error( "Connection to ({}, {}) failed: {}", - self.params.host, - self.params.port, + self.comm_params.host, + self.comm_params.port, msg, ) self.close() @@ -157,7 +155,7 @@ def close(self): # pylint: disable=arguments-differ def _check_read_buffer(self): """Check read buffer.""" time_ = time.time() - end = time_ + self.params.timeout + end = time_ + self.comm_params.timeout_connect data = None ready = select.select([self.socket], [], [], end - time_) if ready[0]: @@ -193,7 +191,7 @@ def recv(self, size): # less than the expected size. self.socket.setblocking(0) - timeout = self.params.timeout + timeout = self.comm_params.timeout_connect # If size isn't specified read up to 4096 bytes at a time. if size is None: @@ -270,11 +268,11 @@ def __str__(self): :returns: The string representation """ - return f"ModbusTcpClient({self.params.host}:{self.params.port})" + return f"ModbusTcpClient({self.comm_params.host}:{self.comm_params.port})" def __repr__(self): """Return string representation.""" return ( f"<{self.__class__.__name__} at {hex(id(self))} socket={self.socket}, " - f"ipaddr={self.params.host}, port={self.params.port}, timeout={self.params.timeout}>" + f"ipaddr={self.comm_params.host}, port={self.comm_params.port}, timeout={self.comm_params.timeout_connect}>" ) diff --git a/pymodbus/client/tls.py b/pymodbus/client/tls.py index 07fde2539..02f941636 100644 --- a/pymodbus/client/tls.py +++ b/pymodbus/client/tls.py @@ -127,7 +127,9 @@ def __init__( ): """Initialize Modbus TLS Client.""" self.transport = None - super().__init__(host, port=port, framer=framer, **kwargs) + super().__init__( + host, CommType=CommType.TLS, port=port, framer=framer, **kwargs + ) self.sslctx = CommParams.generate_ssl( False, certfile, keyfile, password, sslctx=sslctx ) @@ -147,15 +149,15 @@ def connect(self): if self.params.source_address: sock.bind(self.params.source_address) self.socket = self.sslctx.wrap_socket( - sock, server_side=False, server_hostname=self.params.host + sock, server_side=False, server_hostname=self.comm_params.host ) - self.socket.settimeout(self.params.timeout) - self.socket.connect((self.params.host, self.params.port)) + self.socket.settimeout(self.comm_params.timeout_connect) + self.socket.connect((self.comm_params.host, self.comm_params.port)) except OSError as msg: Log.error( "Connection to ({}, {}) failed: {}", - self.params.host, - self.params.port, + self.comm_params.host, + self.comm_params.port, msg, ) self.close() @@ -163,12 +165,12 @@ def connect(self): def __str__(self): """Build a string representation of the connection.""" - return f"ModbusTlsClient({self.params.host}:{self.params.port})" + return f"ModbusTlsClient({self.comm_params.host}:{self.comm_params.port})" def __repr__(self): """Return string representation.""" return ( f"<{self.__class__.__name__} at {hex(id(self))} socket={self.socket}, " - f"ipaddr={self.params.host}, port={self.params.port}, sslctx={self.sslctx}, " - f"timeout={self.params.timeout}>" + f"ipaddr={self.comm_params.host}, port={self.comm_params.port}, sslctx={self.sslctx}, " + f"timeout={self.comm_params.timeout_connect}>" ) diff --git a/pymodbus/client/udp.py b/pymodbus/client/udp.py index 19d65a4d8..ce953a21b 100644 --- a/pymodbus/client/udp.py +++ b/pymodbus/client/udp.py @@ -54,7 +54,6 @@ def __init__( ModbusBaseClient.__init__( self, framer=framer, CommType=CommType.UDP, host=host, port=port, **kwargs ) - self.params.port = port self.params.source_address = source_address @property @@ -117,9 +116,9 @@ def __init__( """Initialize Modbus UDP Client.""" kwargs["use_sync"] = True self.transport = None - super().__init__(framer=framer, **kwargs) - self.params.host = host - self.params.port = port + super().__init__( + framer=framer, port=port, host=host, CommType=CommType.UDP, **kwargs + ) self.params.source_address = source_address self.socket = None @@ -137,9 +136,9 @@ def connect(self): # pylint: disable=invalid-overridden-method if self.socket: return True try: - family = ModbusUdpClient._get_address_family(self.params.host) + family = ModbusUdpClient._get_address_family(self.comm_params.host) self.socket = socket.socket(family, socket.SOCK_DGRAM) - self.socket.settimeout(self.params.timeout) + self.socket.settimeout(self.comm_params.timeout_connect) except OSError as exc: Log.error("Unable to create udp socket {}", exc) self.close() @@ -161,7 +160,9 @@ def send(self, request): if not self.socket: raise ConnectionException(str(self)) if request: - return self.socket.sendto(request, (self.params.host, self.params.port)) + return self.socket.sendto( + request, (self.comm_params.host, self.comm_params.port) + ) return 0 def recv(self, size): @@ -183,11 +184,11 @@ def is_socket_open(self): def __str__(self): """Build a string representation of the connection.""" - return f"ModbusUdpClient({self.params.host}:{self.params.port})" + return f"ModbusUdpClient({self.comm_params.host}:{self.comm_params.port})" def __repr__(self): """Return string representation.""" return ( f"<{self.__class__.__name__} at {hex(id(self))} socket={self.socket}, " - f"ipaddr={self.params.host}, port={self.params.port}, timeout={self.params.timeout}>" + f"ipaddr={self.comm_params.host}, port={self.comm_params.port}, timeout={self.comm_params.timeout_connect}>" ) diff --git a/pymodbus/framer/rtu_framer.py b/pymodbus/framer/rtu_framer.py index 3f75c385e..22077bf76 100644 --- a/pymodbus/framer/rtu_framer.py +++ b/pymodbus/framer/rtu_framer.py @@ -248,7 +248,7 @@ def sendPacket(self, message): :return: """ start = time.time() - timeout = start + self.client.params.timeout + timeout = start + self.client.comm_params.timeout_connect while self.client.state != ModbusTransactionState.IDLE: if self.client.state == ModbusTransactionState.TRANSACTION_COMPLETE: timestamp = round(time.time(), 6) @@ -272,7 +272,7 @@ def sendPacket(self, message): elif self.client.state == ModbusTransactionState.RETRYING: # Simple lets settle down!!! # To check for higher baudrates - time.sleep(self.client.params.timeout) + time.sleep(self.client.comm_params.timeout_connect) break elif time.time() > timeout: Log.debug( diff --git a/pymodbus/repl/client/mclient.py b/pymodbus/repl/client/mclient.py index 956b69293..0d6e49dc5 100644 --- a/pymodbus/repl/client/mclient.py +++ b/pymodbus/repl/client/mclient.py @@ -566,14 +566,14 @@ def get_port(self): :return: Current Serial port """ - return self.params.port + return self.comm_params.port def set_port(self, value): """Set serial Port setter. :param value: New port """ - self.params.port = value + self.comm_params.port = value if self.is_socket_open(): self.close() @@ -598,7 +598,7 @@ def get_bytesize(self): :return: Current bytesize """ - return self.params.bytesize + return self.comm_params.bytesize def set_bytesize(self, value): """Set Byte size. @@ -606,7 +606,7 @@ def set_bytesize(self, value): :param value: Possible values (5, 6, 7, 8) """ - self.params.bytesize = int(value) + self.comm_params.bytesize = int(value) if self.is_socket_open(): self.close() @@ -631,14 +631,14 @@ def get_baudrate(self): :return: Current baudrate """ - return self.params.baudrate + return self.comm_params.baudrate def set_baudrate(self, value): """Set baudrate setter. :param value: """ - self.params.baudrate = int(value) + self.comm_params.baudrate = int(value) if self.is_socket_open(): self.close() @@ -647,14 +647,14 @@ def get_timeout(self): :return: Current read imeout. """ - return self.params.timeout + return self.comm_params.timeout_connect def set_timeout(self, value): """Read timeout setter. :param value: Read Timeout in seconds """ - self.params.timeout = float(value) + self.comm_params.timeout_connect = float(value) if self.is_socket_open(): self.close() @@ -664,12 +664,12 @@ def get_serial_settings(self): :return: Current Serial settings as dict. """ return { - "baudrate": self.params.baudrate, - "port": self.params.port, - "parity": self.params.parity, - "stopbits": self.params.stopbits, - "bytesize": self.params.bytesize, - "read timeout": self.params.timeout, + "baudrate": self.comm_params.baudrate, + "port": self.comm_params.port, + "parity": self.comm_params.parity, + "stopbits": self.comm_params.stopbits, + "bytesize": self.comm_params.bytesize, + "read timeout": self.comm_params.timeout_connect, "t1.5": self.inter_char_timeout, "t3.5": self.silent_interval, } diff --git a/pymodbus/server/async_io.py b/pymodbus/server/async_io.py index 2285ef85e..303361397 100644 --- a/pymodbus/server/async_io.py +++ b/pymodbus/server/async_io.py @@ -50,8 +50,8 @@ def __init__(self, owner): reconnect_delay=0.0, reconnect_delay_max=0.0, timeout_connect=0.0, - host=owner.comm_params.host, - port=owner.comm_params.port, + host=owner.comm_params.source_address[0], + port=owner.comm_params.source_address[1], ) super().__init__(params, True) self.server = owner @@ -338,8 +338,8 @@ def __init__( timeout_connect=0.0, ), ) - params.host = address[0] - params.port = address[1] + params.source_address = address + super().__init__( params, True, @@ -515,8 +515,7 @@ def __init__( CommParams( comm_type=CommType.UDP, comm_name="server_listener", - host=address[0], - port=address[1], + source_address=address, reconnect_delay=0.0, reconnect_delay_max=0.0, timeout_connect=0.0, @@ -619,7 +618,7 @@ def __init__( reconnect_delay=kwargs.get("reconnect_delay", 2), reconnect_delay_max=0.0, timeout_connect=kwargs.get("timeout", 3), - host=kwargs.get("port", 0), + source_address=(kwargs.get("port", 0), 0), bytesize=kwargs.get("bytesize", 8), parity=kwargs.get("parity", "N"), baudrate=kwargs.get("baudrate", 19200), diff --git a/pymodbus/transport/transport.py b/pymodbus/transport/transport.py index 43a9da1da..50cfc076b 100644 --- a/pymodbus/transport/transport.py +++ b/pymodbus/transport/transport.py @@ -35,12 +35,9 @@ class CommParams: reconnect_delay: float = None reconnect_delay_max: float = None timeout_connect: float = None - - # tcp / tls / udp / serial - host: str = None - - # tcp / tls / udp - port: int = None + host: str = "localhost" + port: int = 0 + source_address: tuple[str, int] = ("localhost", 0) # tls sslctx: ssl.SSLContext = None @@ -94,6 +91,20 @@ class ModbusProtocol(asyncio.BaseProtocol): Contains high level methods like reconnect. + Host/Port/SourceAddress explanation: + - SourceAddress: + - server: (host, port) to listen on (default is ("localhost", 502/802)) + - server serial: (host, _) to open/connect and listen on + - client: (Bind local part to interface (default is local interface) + - client serial: (host, _) to open/connect and listen on + - Host + - Server: not used + - Client serial: port string to use for connecting + - Client others: remote host to connect to + - Port + - Server/Client serial: not used + - Client others: remote port to connect to + The class is designed to take care of differences between the different transport mediums, and provide a neutral interface for the upper layers. """ @@ -106,57 +117,52 @@ def __init__( """Initialize a transport instance. :param params: parameter dataclass - :param is_server: true if object act as a server (listen allowed) - :param callback_connected: Called when connection is established - :param callback_disconnected: Called when connection is disconnected - :param callback_data: Called when data is received + :param is_server: true if object act as a server (listen/connect) """ self.comm_params = params.copy() self.is_server = is_server + self.is_closing = False - self.reconnect_delay_current: float = 0.0 - self.listener: ModbusProtocol = None self.transport: asyncio.BaseModbusProtocol | asyncio.Server = None self.loop: asyncio.AbstractEventLoop = None - self.reconnect_task: asyncio.Task = None self.recv_buffer: bytes = b"" self.call_create: Callable[[], Coroutine[Any, Any, Any]] = lambda: None - self.active_connections: dict[str, ModbusProtocol] = {} - self.unique_id: str = str(id(self)) + if self.is_server: + self.active_connections: dict[str, ModbusProtocol] = {} + else: + self.listener: ModbusProtocol = None + self.unique_id: str = str(id(self)) + self.reconnect_task: asyncio.Task = None + self.reconnect_delay_current: float = 0.0 # ModbusProtocol specific setup - if self.comm_params.host.startswith(NULLMODEM_HOST): - if self.comm_params.comm_type == CommType.SERIAL: - self.comm_params.port = int(self.comm_params.host[9:].split(":")[1]) - self.call_create = self.create_nullmodem + if self.comm_params.comm_type == CommType.SERIAL: + self.init_correct_serial() + if self.init_check_nullmodem(): return + if self.comm_params.comm_type == CommType.SERIAL: - if self.comm_params.host.startswith("socket:") and is_server: - parts = self.comm_params.host[9:].split(":") - self.comm_params.host = parts[0] - self.comm_params.port = int(parts[1]) - self.comm_params.comm_type = CommType.TCP - else: - self.call_create = lambda: create_serial_connection( - self.loop, - self.handle_new_connection, - self.comm_params.host, - baudrate=self.comm_params.baudrate, - bytesize=self.comm_params.bytesize, - parity=self.comm_params.parity, - stopbits=self.comm_params.stopbits, - timeout=self.comm_params.timeout_connect, - ) - return + self.call_create = lambda: create_serial_connection( + self.loop, + self.handle_new_connection, + self.comm_params.host, + baudrate=self.comm_params.baudrate, + bytesize=self.comm_params.bytesize, + parity=self.comm_params.parity, + stopbits=self.comm_params.stopbits, + timeout=self.comm_params.timeout_connect, + ) + return if self.comm_params.comm_type == CommType.UDP: if is_server: self.call_create = lambda: self.loop.create_datagram_endpoint( self.handle_new_connection, - local_addr=(self.comm_params.host, self.comm_params.port), + local_addr=self.comm_params.source_address, ) else: self.call_create = lambda: self.loop.create_datagram_endpoint( self.handle_new_connection, + local_addr=self.comm_params.source_address, remote_addr=(self.comm_params.host, self.comm_params.port), ) return @@ -164,8 +170,8 @@ def __init__( if is_server: self.call_create = lambda: self.loop.create_server( self.handle_new_connection, - self.comm_params.host, - self.comm_params.port, + self.comm_params.source_address[0], + self.comm_params.source_address[1], ssl=self.comm_params.sslctx, reuse_address=True, start_serving=True, @@ -175,14 +181,42 @@ def __init__( self.handle_new_connection, self.comm_params.host, self.comm_params.port, + local_addr=self.comm_params.source_address, ssl=self.comm_params.sslctx, ) + def init_correct_serial(self) -> None: + """Split host for serial if needed.""" + if self.is_server: + host = self.comm_params.source_address[0] + if host.startswith("socket"): + parts = host[9:].split(":") + self.comm_params.source_address = (parts[0], int(parts[1])) + self.comm_params.comm_type = CommType.TCP + elif host.startswith(NULLMODEM_HOST): + self.comm_params.source_address = (host, int(host[9:].split(":")[1])) + return + if self.comm_params.host.startswith(NULLMODEM_HOST): + self.comm_params.port = int(self.comm_params.host[9:].split(":")[1]) + + def init_check_nullmodem(self) -> bool: + """Check if nullmodem is needed.""" + if self.comm_params.host.startswith(NULLMODEM_HOST): + port = self.comm_params.port + elif self.comm_params.source_address[0].startswith(NULLMODEM_HOST): + port = self.comm_params.source_address[1] + else: + return False + + self.call_create = lambda: self.create_nullmodem(port) + return True + async def transport_connect(self) -> bool: """Handle generic connect and call on to specific transport connect.""" Log.debug("Connecting {}", self.comm_params.comm_name) if not self.loop: self.loop = asyncio.get_running_loop() + self.is_closing = False try: self.transport, _protocol = await asyncio.wait_for( self.call_create(), @@ -193,7 +227,7 @@ async def transport_connect(self) -> bool: OSError, ) as exc: Log.warning("Failed to connect {}", exc) - self.transport_close(reconnect=True) + self.transport_close(intern=True, reconnect=True) return False return bool(self.transport) @@ -202,13 +236,14 @@ async def transport_listen(self) -> bool: Log.debug("Awaiting connections {}", self.comm_params.comm_name) if not self.loop: self.loop = asyncio.get_running_loop() + self.is_closing = False try: self.transport = await self.call_create() if isinstance(self.transport, tuple): self.transport = self.transport[0] except OSError as exc: Log.warning("Failed to start server {}", exc) - self.transport_close() + self.transport_close(intern=True) return False return True @@ -230,11 +265,11 @@ def connection_lost(self, reason: Exception): :param reason: None or an exception object """ - if not self.transport: + if not self.transport or self.is_closing: return Log.debug("Connection lost {} due to {}", self.comm_params.comm_name, reason) - self.transport_close() - if not self.is_server: + self.transport_close(intern=True) + if not self.is_server and not self.listener: self.reconnect_task = asyncio.create_task(self.do_reconnect()) self.callback_disconnected(reason) @@ -298,28 +333,34 @@ def transport_send(self, data: bytes, addr: tuple = None) -> None: else: self.transport.write(data) - def transport_close(self, reconnect: bool = False) -> None: + def transport_close(self, intern: bool = False, reconnect: bool = False) -> None: """Close connection. + :param intern: (default false), True if called internally (temporary close) :param reconnect: (default false), try to reconnect """ + if self.is_closing: + return + if not intern: + self.is_closing = True if self.transport: if hasattr(self.transport, "abort"): self.transport.abort() self.transport.close() self.transport = None + self.recv_buffer = b"" + if self.is_server: + for _key, value in self.active_connections.items(): + value.listener = None + value.transport_close() + self.active_connections = {} + return if not reconnect and self.reconnect_task: self.reconnect_task.cancel() self.reconnect_task = None self.reconnect_delay_current = 0.0 - self.recv_buffer = b"" if self.listener: self.listener.active_connections.pop(self.unique_id) - elif self.is_server: - for _key, value in self.active_connections.items(): - value.listener = None - value.transport_close() - self.active_connections = {} def reset_delay(self) -> None: """Reset wait time before next reconnect to minimal period.""" @@ -332,20 +373,18 @@ def is_active(self) -> bool: # ---------------- # # Internal methods # # ---------------- # - async def create_nullmodem(self): + async def create_nullmodem(self, port): """Bypass create_ and use null modem""" if self.is_server: # Listener object self.transport = NullModem(self) - NullModem.listener_new_connection[ - self.comm_params.port - ] = self.handle_new_connection + NullModem.listener_new_connection[port] = self.handle_new_connection return self.transport, self # connect object client_protocol = self.handle_new_connection() try: - server_protocol = NullModem.listener_new_connection[self.comm_params.port]() + server_protocol = NullModem.listener_new_connection[port]() except KeyError as exc: raise asyncio.TimeoutError( f"No listener on port {self.comm_params.port} for connect" @@ -362,9 +401,10 @@ async def create_nullmodem(self): def handle_new_connection(self): """Handle incoming connect.""" if not self.is_server: + # Clients reuse the same object. return self - new_protocol = ModbusProtocol(self.comm_params, True) + new_protocol = ModbusProtocol(self.comm_params, False) self.active_connections[new_protocol.unique_id] = new_protocol new_protocol.listener = self return new_protocol diff --git a/test/sub_transport/conftest.py b/test/sub_transport/conftest.py index 6cba52485..be86cbe78 100644 --- a/test/sub_transport/conftest.py +++ b/test/sub_transport/conftest.py @@ -64,8 +64,27 @@ def prepare_dummy_use_host(): return "localhost" -@pytest.fixture(name="commparams") -def prepare_commparams(use_port, use_host, use_comm_type): +@pytest.fixture(name="use_cls") +def prepare_commparams_server(use_port, use_host, use_comm_type): + """Prepare CommParamsClass object.""" + if use_host == NULLMODEM_HOST and use_comm_type == CommType.SERIAL: + use_host = f"{NULLMODEM_HOST}:{use_port}" + return CommParams( + comm_name="test comm", + comm_type=use_comm_type, + reconnect_delay=0, + reconnect_delay_max=0, + timeout_connect=0, + source_address=(use_host, use_port), + baudrate=9600, + bytesize=8, + parity="E", + stopbits=2, + ) + + +@pytest.fixture(name="use_clc") +def prepare_commparams_client(use_port, use_host, use_comm_type): """Prepare CommParamsClass object.""" if use_host == NULLMODEM_HOST and use_comm_type == CommType.SERIAL: use_host = f"{NULLMODEM_HOST}:{use_port}" @@ -85,36 +104,36 @@ def prepare_commparams(use_port, use_host, use_comm_type): @pytest.fixture(name="client") -async def prepare_protocol(commparams): +async def prepare_protocol(use_clc): """Prepare transport object.""" - transport = ModbusProtocol(commparams, False) + transport = ModbusProtocol(use_clc, False) with suppress(RuntimeError): transport.loop = asyncio.get_running_loop() transport.callback_connected = mock.Mock() transport.callback_disconnected = mock.Mock() transport.callback_data = mock.Mock(return_value=0) - if commparams.comm_type == CommType.TLS: + if use_clc.comm_type == CommType.TLS: cwd = os.path.dirname(__file__) + "/../../examples/certificates/pymodbus." - transport.comm_params.sslctx = commparams.generate_ssl( + transport.comm_params.sslctx = use_clc.generate_ssl( False, certfile=cwd + "crt", keyfile=cwd + "key" ) - if commparams.comm_type == CommType.SERIAL: + if use_clc.comm_type == CommType.SERIAL: transport.comm_params.host = f"socket://localhost:{transport.comm_params.port}" return transport @pytest.fixture(name="server") -async def prepare_transport_server(commparams): +async def prepare_transport_server(use_cls): """Prepare transport object.""" - transport = ModbusProtocol(commparams, True) + transport = ModbusProtocol(use_cls, True) with suppress(RuntimeError): transport.loop = asyncio.get_running_loop() transport.callback_connected = mock.Mock() transport.callback_disconnected = mock.Mock() transport.callback_data = mock.Mock(return_value=0) - if commparams.comm_type == CommType.TLS: + if use_cls.comm_type == CommType.TLS: cwd = os.path.dirname(__file__) + "/../../examples/certificates/pymodbus." - transport.comm_params.sslctx = commparams.generate_ssl( + transport.comm_params.sslctx = use_cls.generate_ssl( True, certfile=cwd + "crt", keyfile=cwd + "key" ) return transport diff --git a/test/sub_transport/test_basic.py b/test/sub_transport/test_basic.py index 18ea1bca4..d5998c464 100644 --- a/test/sub_transport/test_basic.py +++ b/test/sub_transport/test_basic.py @@ -24,52 +24,37 @@ class TestBasicModbusProtocol: """Test transport module.""" @pytest.mark.parametrize("use_comm_type", COMM_TYPES) - async def test_init(self, client, server, commparams): + async def test_init(self, client, server): """Test init()""" - if commparams.comm_type == CommType.SERIAL: - client.comm_params.host = commparams.host - server.comm_params.comm_type = commparams.comm_type client.comm_params.sslctx = None - assert client.comm_params == commparams assert client.unique_id == str(id(client)) + assert not hasattr(client, "active_connections") assert not client.is_server server.comm_params.sslctx = None - assert server.comm_params == commparams - assert server.unique_id == str(id(server)) + assert not hasattr(server, "unique_id") + assert not server.active_connections assert server.is_server @pytest.mark.parametrize("use_host", [NULLMODEM_HOST]) @pytest.mark.parametrize("use_comm_type", COMM_TYPES) - async def test_init_nullmodem(self, client, server, commparams): + async def test_init_nullmodem(self, client, server): """Test init()""" - if commparams.comm_type == CommType.SERIAL: - client.comm_params.host = commparams.host - server.comm_params.comm_type = commparams.comm_type client.comm_params.sslctx = None - assert client.comm_params == commparams assert client.unique_id == str(id(client)) + assert not hasattr(client, "active_connections") assert not client.is_server - server.comm_params.sslctx = None - assert server.comm_params == commparams - assert server.unique_id == str(id(server)) + assert not hasattr(server, "unique_id") + assert not server.active_connections assert server.is_server @pytest.mark.parametrize( ("use_host", "use_comm_type"), [("socket://127.0.0.1:7001", CommType.SERIAL)] ) - async def test_init_serial(self, client, server, commparams): + async def test_init_serial(self, client, server): """Test init()""" - client.comm_params.host = commparams.host - client.comm_params.sslctx = None - server.comm_params.host = commparams.host - server.comm_params.port = commparams.port - server.comm_params.comm_type = commparams.comm_type - assert client.comm_params == commparams assert client.unique_id == str(id(client)) assert not client.is_server server.comm_params.sslctx = None - assert server.comm_params == commparams - assert server.unique_id == str(id(server)) assert server.is_server async def test_connect(self, client, dummy_protocol): @@ -89,13 +74,13 @@ async def test_listen(self, server, dummy_protocol): server.call_create.side_effect = OSError("testing") assert not await server.transport_listen() - async def test_connection_made(self, client, commparams, dummy_protocol): + async def test_connection_made(self, client, use_clc, dummy_protocol): """Test connection_made().""" client.connection_made(dummy_protocol) assert client.transport assert not client.recv_buffer assert not client.reconnect_task - assert client.reconnect_delay_current == commparams.reconnect_delay + assert client.reconnect_delay_current == use_clc.reconnect_delay client.callback_connected.assert_called_once() async def test_connection_lost(self, client, dummy_protocol): @@ -135,9 +120,9 @@ async def test_error_received(self, client): with pytest.raises(RuntimeError): client.error_received(Exception("test call")) - async def test_callbacks(self, commparams): + async def test_callbacks(self, use_clc): """Test callbacks.""" - client = ModbusProtocol(commparams, False) + client = ModbusProtocol(use_clc, False) client.callback_connected() client.callback_disconnected(Exception("test")) client.callback_data(b"abcd") @@ -158,22 +143,20 @@ async def test_transport_close(self, server, dummy_protocol): server.connection_made(dummy_protocol) server.recv_buffer = b"abc" server.reconnect_task = mock.MagicMock() - server.listener = mock.MagicMock() server.transport_close() dummy_protocol.abort.assert_called_once() dummy_protocol.close.assert_called_once() assert not server.recv_buffer - assert not server.reconnect_task - server.listener = None + await server.transport_listen() server.active_connections = {"a": dummy_protocol} server.transport_close() assert not server.active_connections - async def test_reset_delay(self, client, commparams): + async def test_reset_delay(self, client, use_clc): """Test reset_delay().""" client.reconnect_delay_current += 5.17 client.reset_delay() - assert client.reconnect_delay_current == commparams.reconnect_delay + assert client.reconnect_delay_current == use_clc.reconnect_delay async def test_is_active(self, client): """Test is_active().""" @@ -212,17 +195,17 @@ async def test_with_magic(self, client): pass client.transport_close.assert_called_once() - async def test_str_magic(self, commparams, client): + async def test_str_magic(self, use_clc, client): """Test magic.""" - assert str(client) == f"ModbusProtocol({commparams.comm_name})" + assert str(client) == f"ModbusProtocol({use_clc.comm_name})" - def test_generate_ssl(self, commparams): + def test_generate_ssl(self, use_clc): """Test ssl generattion""" with mock.patch("pymodbus.transport.transport.ssl.SSLContext"): - sslctx = commparams.generate_ssl(True, "cert_file", "key_file") + sslctx = use_clc.generate_ssl(True, "cert_file", "key_file") assert sslctx test_value = "test igen" - assert test_value == commparams.generate_ssl( + assert test_value == use_clc.generate_ssl( True, "cert_file", "key_file", sslctx=test_value ) diff --git a/test/sub_transport/test_comm.py b/test/sub_transport/test_comm.py index ed57c384c..16cc73e2e 100644 --- a/test/sub_transport/test_comm.py +++ b/test/sub_transport/test_comm.py @@ -111,11 +111,11 @@ async def test_connected(self, client, server, use_comm_type): server_connected.transport_send(test_data, addr=addr) else: server_connected.transport_send(test_data) - await asyncio.sleep(2) + await asyncio.sleep(1) assert client.recv_buffer == test_data assert not server_connected.recv_buffer client.transport_close() - await asyncio.sleep(2) + await asyncio.sleep(1) if use_comm_type != CommType.UDP: assert not server.active_connections server.transport_close() @@ -126,7 +126,7 @@ async def test_connected(self, client, server, use_comm_type): (CommType.TCP, "localhost", BASE_PORT + 21), ], ) - async def test_connected_multiple(self, client, server, commparams): + async def test_connected_multiple(self, client, server, use_clc): """Test connection and data exchange.""" assert await server.transport_listen() assert await client.transport_connect() @@ -134,9 +134,9 @@ async def test_connected_multiple(self, client, server, commparams): assert len(server.active_connections) == 1 server_connected = list(server.active_connections.values())[0] - c2_params = commparams.copy() + c2_params = use_clc.copy() c2_params.port = client.comm_params.port + 1 - client2 = ModbusProtocol(commparams, False) + client2 = ModbusProtocol(use_clc, False) client2.callback_connected = mock.Mock() client2.callback_disconnected = mock.Mock() client2.callback_data = mock.Mock(return_value=0) diff --git a/test/sub_transport/test_reconnect.py b/test/sub_transport/test_reconnect.py index e3e203783..52d442293 100644 --- a/test/sub_transport/test_reconnect.py +++ b/test/sub_transport/test_reconnect.py @@ -16,7 +16,7 @@ async def test_no_reconnect_call(self, client): assert not client.reconnect_delay_current client.transport_close() - async def test_reconnect_call(self, client, commparams): + async def test_reconnect_call(self, client, use_clc): """Test connection_lost().""" client.loop.create_connection = mock.AsyncMock(return_value=(None, None)) await client.transport_connect() @@ -26,10 +26,10 @@ async def test_reconnect_call(self, client, commparams): await asyncio.sleep(client.reconnect_delay_current * 1.8) assert client.reconnect_task assert client.loop.create_connection.call_count == 2 - assert client.reconnect_delay_current == commparams.reconnect_delay * 2 + assert client.reconnect_delay_current == use_clc.reconnect_delay * 2 client.transport_close() - async def test_multi_reconnect_call(self, client, commparams): + async def test_multi_reconnect_call(self, client, use_clc): """Test connection_lost().""" client.loop.create_connection = mock.AsyncMock(return_value=(None, None)) await client.transport_connect() @@ -37,16 +37,16 @@ async def test_multi_reconnect_call(self, client, commparams): client.connection_lost(RuntimeError("Connection lost")) await asyncio.sleep(client.reconnect_delay_current * 1.8) assert client.loop.create_connection.call_count == 2 - assert client.reconnect_delay_current == commparams.reconnect_delay * 2 + assert client.reconnect_delay_current == use_clc.reconnect_delay * 2 await asyncio.sleep(client.reconnect_delay_current * 1.8) assert client.loop.create_connection.call_count == 3 - assert client.reconnect_delay_current == commparams.reconnect_delay_max + assert client.reconnect_delay_current == use_clc.reconnect_delay_max await asyncio.sleep(client.reconnect_delay_current * 1.8) assert client.loop.create_connection.call_count >= 4 - assert client.reconnect_delay_current == commparams.reconnect_delay_max + assert client.reconnect_delay_current == use_clc.reconnect_delay_max client.transport_close() - async def test_reconnect_call_ok(self, client, commparams): + async def test_reconnect_call_ok(self, client, use_clc): """Test connection_lost().""" client.loop.create_connection = mock.AsyncMock( return_value=(mock.Mock(), mock.Mock()) @@ -56,6 +56,6 @@ async def test_reconnect_call_ok(self, client, commparams): client.connection_lost(RuntimeError("Connection lost")) await asyncio.sleep(client.reconnect_delay_current * 1.8) assert client.loop.create_connection.call_count == 2 - assert client.reconnect_delay_current == commparams.reconnect_delay + assert client.reconnect_delay_current == use_clc.reconnect_delay assert not client.reconnect_task client.transport_close() diff --git a/test/test_client_sync.py b/test/test_client_sync.py index cca28f252..f5bd762e9 100755 --- a/test/test_client_sync.py +++ b/test/test_client_sync.py @@ -95,7 +95,7 @@ def test_udp_client_repr(self): client = ModbusUdpClient("127.0.0.1") rep = ( f"<{client.__class__.__name__} at {hex(id(client))} socket={client.socket}, " - f"ipaddr={client.params.host}, port={client.params.port}, timeout={client.params.timeout}>" + f"ipaddr={client.comm_params.host}, port={client.comm_params.port}, timeout={client.comm_params.timeout_connect}>" ) assert repr(client) == rep @@ -160,7 +160,7 @@ def test_tcp_client_recv(self, mock_select, mock_time): mock_socket = mock.MagicMock() mock_socket.recv.side_effect = iter([b"\x00", b"\x01", b"\x02"]) client.socket = mock_socket - client.params.timeout = 3 + client.comm_params.timeout_connect = 3 assert client.recv(3) == b"\x00\x01\x02" mock_socket.recv.side_effect = iter([b"\x00", b"\x01", b"\x02"]) assert client.recv(2) == b"\x00\x01" @@ -185,7 +185,7 @@ def test_tcp_client_repr(self): client = ModbusTcpClient("127.0.0.1") rep = ( f"<{client.__class__.__name__} at {hex(id(client))} socket={client.socket}, " - f"ipaddr={client.params.host}, port={client.params.port}, timeout={client.params.timeout}>" + f"ipaddr={client.comm_params.host}, port={client.comm_params.port}, timeout={client.comm_params.timeout_connect}>" ) assert repr(client) == rep @@ -264,7 +264,7 @@ def test_tls_client_recv(self, mock_select, mock_time): assert client.recv(0) == b"" assert client.recv(4) == b"\x00" * 4 - client.params.timeout = 2 + client.comm_params.timeout_connect = 2 client.socket.mock_prepare_receive(b"\x00") assert b"\x00" in client.recv(None) @@ -273,8 +273,8 @@ def test_tls_client_repr(self): client = ModbusTlsClient("127.0.0.1") rep = ( f"<{client.__class__.__name__} at {hex(id(client))} socket={client.socket}, " - f"ipaddr={client.params.host}, port={client.params.port}, sslctx={client.sslctx}, " - f"timeout={client.params.timeout}>" + f"ipaddr={client.comm_params.host}, port={client.comm_params.port}, sslctx={client.sslctx}, " + f"timeout={client.comm_params.timeout_connect}>" ) assert repr(client) == rep @@ -424,6 +424,6 @@ def test_serial_client_repr(self): client = ModbusSerialClient("/dev/null") rep = ( f"<{client.__class__.__name__} at {hex(id(client))} socket={client.socket}, " - f"framer={client.framer}, timeout={client.params.timeout}>" + f"framer={client.framer}, timeout={client.comm_params.timeout_connect}>" ) assert repr(client) == rep diff --git a/test/test_framers.py b/test/test_framers.py index bc645d398..baee88749 100644 --- a/test/test_framers.py +++ b/test/test_framers.py @@ -316,7 +316,7 @@ def test_send_packet(rtu_framer): client.state = ModbusTransactionState.TRANSACTION_COMPLETE client.silent_interval = 1 client.last_frame_end = 1 - client.params.timeout = 0.25 + client.comm_params.timeout_connect = 0.25 client.idle_time = mock.Mock(return_value=1) client.send = mock.Mock(return_value=len(message)) rtu_framer.client = client From f11b50cf457f5db90db18a35c6ea17d0870ebfe3 Mon Sep 17 00:00:00 2001 From: jan iversen Date: Wed, 5 Jul 2023 17:32:52 +0200 Subject: [PATCH 2/4] please python 3.8 --- pymodbus/client/base.py | 5 +- pymodbus/client/serial.py | 4 -- pymodbus/server/async_io.py | 102 +++++++++---------------------- pymodbus/transport/transport.py | 76 +++++++++++++---------- test/sub_transport/test_basic.py | 34 ++++++++++- 5 files changed, 108 insertions(+), 113 deletions(-) diff --git a/pymodbus/client/base.py b/pymodbus/client/base.py index dbf0976e1..8d4d77f12 100644 --- a/pymodbus/client/base.py +++ b/pymodbus/client/base.py @@ -59,8 +59,6 @@ class _params: broadcast_enable: bool = None reconnect_delay: int = None - handle_local_echo: bool = None - source_address: tuple[str, int] = None server_hostname: str = None @@ -85,7 +83,7 @@ def __init__( # pylint: disable=too-many-arguments setup_params = CommParams( comm_type=kwargs.get("CommType"), comm_name="comm", - source_address=kwargs.get("source_address", ("localhost", 0)), + source_address=kwargs.get("source_address", ("127.0.0.1", 0)), reconnect_delay=reconnect_delay, reconnect_delay_max=reconnect_delay_max, timeout_connect=timeout, @@ -96,6 +94,7 @@ def __init__( # pylint: disable=too-many-arguments bytesize=kwargs.get("bytesize", None), parity=kwargs.get("parity", None), stopbits=kwargs.get("stopbits", None), + handle_local_echo=kwargs.get("handle_local_echo", False), ) if not self.use_sync: ModbusProtocol.__init__( diff --git a/pymodbus/client/serial.py b/pymodbus/client/serial.py index 50f7597da..addb5bdda 100644 --- a/pymodbus/client/serial.py +++ b/pymodbus/client/serial.py @@ -52,7 +52,6 @@ def __init__( bytesize: int = 8, parity: str = "N", stopbits: int = 1, - handle_local_echo: bool = False, **kwargs: Any, ) -> None: """Initialize Asyncio Modbus Serial Client.""" @@ -68,7 +67,6 @@ def __init__( stopbits=stopbits, **kwargs, ) - self.params.handle_local_echo = handle_local_echo @property def connected(self): @@ -127,7 +125,6 @@ def __init__( bytesize: int = 8, parity: str = "N", stopbits: int = 1, - handle_local_echo: bool = False, **kwargs: Any, ) -> None: """Initialize Modbus Serial Client.""" @@ -144,7 +141,6 @@ def __init__( stopbits=stopbits, **kwargs, ) - self.params.handle_local_echo = handle_local_echo self.socket = None self.last_frame_end = None diff --git a/pymodbus/server/async_io.py b/pymodbus/server/async_io.py index 303361397..def7898a9 100644 --- a/pymodbus/server/async_io.py +++ b/pymodbus/server/async_io.py @@ -47,6 +47,7 @@ def __init__(self, owner): """Initialize.""" params = CommParams( comm_name="server", + comm_type=owner.comm_params.comm_type, reconnect_delay=0.0, reconnect_delay_max=0.0, timeout_connect=0.0, @@ -58,7 +59,6 @@ def __init__(self, owner): self.running = False self.receive_queue = asyncio.Queue() self.handler_task = None # coroutine to be run on asyncio loop - self._sent = b"" # for handle_local_echo self.client_address = (None, None) self.framer: ModbusFramer = None @@ -69,22 +69,6 @@ def _log_exception(self): def callback_connected(self) -> None: """Call when connection is succcesfull.""" try: - if ( - hasattr(self.transport, "get_extra_info") - and self.transport.get_extra_info("peername") is not None - ): - self.client_address = self.transport.get_extra_info("peername")[:2] - Log.debug("Peer [{}] opened", self.client_address) - elif hasattr(self.transport, "serial"): - Log.debug( - "Serial connection opened on port: {}", self.transport.serial.port - ) - self.client_address = ("serial", "server") - else: - Log.warning( - "Unable to get information about transport {}", self.transport - ) - self.transport = self.transport self.running = True self.framer = self.server.framer( self.server.decoder, @@ -189,7 +173,7 @@ async def handle(self): exc, client_addr, ) - self.transport.close() + self.transport_close() else: Log.error("Unknown error occurred {}", exc) reset_frame = True # graceful recovery @@ -239,23 +223,13 @@ def execute(self, request, *addr): response, skip_encoding = self.server.response_manipulator(response) self.send(response, *addr, skip_encoding=skip_encoding) - def send(self, message, *addr, **kwargs): + def send(self, message, addr, **kwargs): """Send message.""" - - def __send(msg, *addr): - Log.debug("send: [{}]- {}", message, msg, ":b2a") - if addr == (None,): - self.transport.write(msg) - if self.server.handle_local_echo is True: - self._sent = msg - else: - self.transport.sendto(msg, *addr) - if kwargs.get("skip_encoding", False): - __send(message, *addr) + self.transport_send(message, addr=addr) elif message.should_respond: pdu = self.framer.buildPacket(message) - __send(pdu, *addr) + self.transport_send(pdu, addr=addr) else: Log.debug("Skipping sending response!!") @@ -270,15 +244,6 @@ async def _recv_(self): # pragma: no cover def callback_data(self, data: bytes, addr: tuple = None) -> int: """Handle received data.""" - if self.server.handle_local_echo is True and self._sent: - if self._sent in data: - data, self._sent = data.replace(self._sent, b"", 1), b"" - elif self._sent.startswith(data): - self._sent, data = self._sent.replace(data, b"", 1), b"" - else: - self._sent = b"" - if not data: - return 0 if addr: self.receive_queue.put_nowait((data, addr)) else: @@ -369,21 +334,21 @@ def handle_new_connection(self): async def serve_forever(self): """Start endless loop.""" - if self.transport is None: - await self.transport_listen() - self.serving.set_result(True) - Log.info("Server(TCP) listening.") - try: - await self.transport.serve_forever() - except asyncio.exceptions.CancelledError: - self.serving_done.set_result(False) - raise - except Exception as exc: # pylint: disable=broad-except - Log.error("Server unexpected exception {}", exc) - else: + if self.transport: raise RuntimeError( "Can't call serve_forever on an already running server object" ) + + await self.transport_listen() + self.serving.set_result(True) + Log.info("Server(TCP) listening.") + try: + await self.transport.serve_forever() + except asyncio.exceptions.CancelledError: + self.serving_done.set_result(False) + raise + except Exception as exc: # pylint: disable=broad-except + Log.error("Server unexpected exception {}", exc) self.serving_done.set_result(True) Log.info("Server graceful shutdown.") @@ -548,23 +513,22 @@ def handle_new_connection(self): async def serve_forever(self): """Start endless loop.""" - if self.transport is None: - try: - await self.transport_listen() - except asyncio.exceptions.CancelledError: - self.serving_done.set_result(False) - raise - except Exception as exc: - Log.error("Server unexpected exception {}", exc) - self.serving_done.set_result(False) - raise RuntimeError(exc) from exc - Log.info("Server(UDP) listening.") - self.serving.set_result(True) - await self.stop_serving - else: + if self.transport: raise RuntimeError( "Can't call serve_forever on an already running server object" ) + try: + await self.transport_listen() + except asyncio.exceptions.CancelledError: + self.serving_done.set_result(False) + raise + except Exception as exc: + Log.error("Server unexpected exception {}", exc) + self.serving_done.set_result(False) + raise RuntimeError(exc) from exc + Log.info("Server(UDP) listening.") + self.serving.set_result(True) + await self.stop_serving self.serving_done.set_result(True) async def shutdown(self): @@ -653,12 +617,6 @@ def handle_new_connection(self): """Handle incoming connect.""" return ModbusServerRequestHandler(self) - def on_connection_lost(self): - """Call on lost connection.""" - if self.transport is not None: - self.transport.close() - self.transport = None - async def shutdown(self): """Terminate server.""" self.transport_close() diff --git a/pymodbus/transport/transport.py b/pymodbus/transport/transport.py index 50cfc076b..0ac3e8445 100644 --- a/pymodbus/transport/transport.py +++ b/pymodbus/transport/transport.py @@ -35,9 +35,10 @@ class CommParams: reconnect_delay: float = None reconnect_delay_max: float = None timeout_connect: float = None - host: str = "localhost" + host: str = "127.0.0.1" port: int = 0 - source_address: tuple[str, int] = ("localhost", 0) + source_address: tuple[str, int] = ("127.0.0.1", 0) + handle_local_echo: bool = False # tls sslctx: ssl.SSLContext = None @@ -93,7 +94,7 @@ class ModbusProtocol(asyncio.BaseProtocol): Host/Port/SourceAddress explanation: - SourceAddress: - - server: (host, port) to listen on (default is ("localhost", 502/802)) + - server: (host, port) to listen on (default is ("127.0.0.1", 502/802)) - server serial: (host, _) to open/connect and listen on - client: (Bind local part to interface (default is local interface) - client serial: (host, _) to open/connect and listen on @@ -134,13 +135,43 @@ def __init__( self.unique_id: str = str(id(self)) self.reconnect_task: asyncio.Task = None self.reconnect_delay_current: float = 0.0 + self.sent_buffer: bytes = b"" # ModbusProtocol specific setup if self.comm_params.comm_type == CommType.SERIAL: self.init_correct_serial() if self.init_check_nullmodem(): return + self.init_setup_connect_listen() + def init_correct_serial(self) -> None: + """Split host for serial if needed.""" + if self.is_server: + host = self.comm_params.source_address[0] + if host.startswith("socket"): + parts = host[9:].split(":") + self.comm_params.source_address = (parts[0], int(parts[1])) + self.comm_params.comm_type = CommType.TCP + elif host.startswith(NULLMODEM_HOST): + self.comm_params.source_address = (host, int(host[9:].split(":")[1])) + return + if self.comm_params.host.startswith(NULLMODEM_HOST): + self.comm_params.port = int(self.comm_params.host[9:].split(":")[1]) + + def init_check_nullmodem(self) -> bool: + """Check if nullmodem is needed.""" + if self.comm_params.host.startswith(NULLMODEM_HOST): + port = self.comm_params.port + elif self.comm_params.source_address[0].startswith(NULLMODEM_HOST): + port = self.comm_params.source_address[1] + else: + return False + + self.call_create = lambda: self.create_nullmodem(port) + return True + + def init_setup_connect_listen(self) -> None: + """Handle connect/listen handler.""" if self.comm_params.comm_type == CommType.SERIAL: self.call_create = lambda: create_serial_connection( self.loop, @@ -154,7 +185,7 @@ def __init__( ) return if self.comm_params.comm_type == CommType.UDP: - if is_server: + if self.is_server: self.call_create = lambda: self.loop.create_datagram_endpoint( self.handle_new_connection, local_addr=self.comm_params.source_address, @@ -162,12 +193,11 @@ def __init__( else: self.call_create = lambda: self.loop.create_datagram_endpoint( self.handle_new_connection, - local_addr=self.comm_params.source_address, remote_addr=(self.comm_params.host, self.comm_params.port), ) return # TLS and TCP - if is_server: + if self.is_server: self.call_create = lambda: self.loop.create_server( self.handle_new_connection, self.comm_params.source_address[0], @@ -185,32 +215,6 @@ def __init__( ssl=self.comm_params.sslctx, ) - def init_correct_serial(self) -> None: - """Split host for serial if needed.""" - if self.is_server: - host = self.comm_params.source_address[0] - if host.startswith("socket"): - parts = host[9:].split(":") - self.comm_params.source_address = (parts[0], int(parts[1])) - self.comm_params.comm_type = CommType.TCP - elif host.startswith(NULLMODEM_HOST): - self.comm_params.source_address = (host, int(host[9:].split(":")[1])) - return - if self.comm_params.host.startswith(NULLMODEM_HOST): - self.comm_params.port = int(self.comm_params.host[9:].split(":")[1]) - - def init_check_nullmodem(self) -> bool: - """Check if nullmodem is needed.""" - if self.comm_params.host.startswith(NULLMODEM_HOST): - port = self.comm_params.port - elif self.comm_params.source_address[0].startswith(NULLMODEM_HOST): - port = self.comm_params.source_address[1] - else: - return False - - self.call_create = lambda: self.create_nullmodem(port) - return True - async def transport_connect(self) -> bool: """Handle generic connect and call on to specific transport connect.""" Log.debug("Connecting {}", self.comm_params.comm_name) @@ -279,6 +283,9 @@ def data_received(self, data: bytes): :param data: non-empty bytes object with incoming data. """ Log.debug("recv: {}", data, ":hex") + if self.comm_params.handle_local_echo and self.sent_buffer == data: + self.sent_buffer = b"" + return self.recv_buffer += data cut = self.callback_data(self.recv_buffer) self.recv_buffer = self.recv_buffer[cut:] @@ -286,6 +293,9 @@ def data_received(self, data: bytes): def datagram_received(self, data: bytes, addr: tuple): """Receive datagram (UDP connections).""" Log.debug("recv: {} addr={}", data, ":hex", addr) + if self.comm_params.handle_local_echo and self.sent_buffer == data: + self.sent_buffer = b"" + return self.recv_buffer += data cut = self.callback_data(self.recv_buffer, addr=addr) self.recv_buffer = self.recv_buffer[cut:] @@ -325,6 +335,8 @@ def transport_send(self, data: bytes, addr: tuple = None) -> None: :param addr: optional addr, only used for UDP server. """ Log.debug("send: {}", data, ":hex") + if self.comm_params.handle_local_echo: + self.sent_buffer = data if self.comm_params.comm_type == CommType.UDP: if addr: self.transport.sendto(data, addr=addr) diff --git a/test/sub_transport/test_basic.py b/test/sub_transport/test_basic.py index d5998c464..e39ef1c43 100644 --- a/test/sub_transport/test_basic.py +++ b/test/sub_transport/test_basic.py @@ -136,10 +136,27 @@ async def test_transport_send(self, client): client.transport_send(b"abc") client.transport_send(b"abc", addr=("localhost", 502)) + async def test_handle_local_echo(self, client): + """Test transport_send().""" + client.comm_params.handle_local_echo = True + client.transport = mock.Mock() + test_data = b"abc" + client.transport_send(test_data) + client.data_received(test_data) + assert not client.recv_buffer + client.data_received(test_data) + assert client.recv_buffer == test_data + client.recv_buffer = b"" + client.transport_send(test_data) + client.datagram_received(test_data, ("127.0.0.1", 502)) + assert not client.recv_buffer + client.datagram_received(test_data, ("127.0.0.1", 502)) + assert client.recv_buffer == test_data + async def test_transport_close(self, server, dummy_protocol): """Test transport_close().""" - dummy_protocol.abort = mock.Mock() - dummy_protocol.close = mock.Mock() + dummy_protocol.abort = mock.MagicMock() + dummy_protocol.close = mock.MagicMock() server.connection_made(dummy_protocol) server.recv_buffer = b"abc" server.reconnect_task = mock.MagicMock() @@ -150,6 +167,19 @@ async def test_transport_close(self, server, dummy_protocol): await server.transport_listen() server.active_connections = {"a": dummy_protocol} server.transport_close() + server.transport_close() + assert not server.active_connections + + async def test_transport_close2(self, server, client, dummy_protocol): + """Test transport_close().""" + dummy_protocol.abort = mock.Mock() + dummy_protocol.close = mock.Mock() + client.connection_made(dummy_protocol) + client.recv_buffer = b"abc" + client.reconnect_task = mock.MagicMock() + client.listener = server + server.active_connections = {client.unique_id: dummy_protocol} + client.transport_close() assert not server.active_connections async def test_reset_delay(self, client, use_clc): From 6ec8a3f09d945b00861710bb5df65d7af2ccf7be Mon Sep 17 00:00:00 2001 From: jan iversen Date: Wed, 5 Jul 2023 20:12:56 +0200 Subject: [PATCH 3/4] async_io.py --- pymodbus/server/async_io.py | 37 +++++++------------------------------ test/test_server_asyncio.py | 2 -- 2 files changed, 7 insertions(+), 32 deletions(-) diff --git a/pymodbus/server/async_io.py b/pymodbus/server/async_io.py index def7898a9..933492a21 100644 --- a/pymodbus/server/async_io.py +++ b/pymodbus/server/async_io.py @@ -59,12 +59,13 @@ def __init__(self, owner): self.running = False self.receive_queue = asyncio.Queue() self.handler_task = None # coroutine to be run on asyncio loop - self.client_address = (None, None) self.framer: ModbusFramer = None def _log_exception(self): """Show log exception.""" - Log.debug("Handler for stream [{}] has been canceled", self.client_address) + Log.debug( + "Handler for stream [{}] has been canceled", self.comm_params.comm_name + ) def callback_connected(self) -> None: """Call when connection is succcesfull.""" @@ -74,7 +75,6 @@ def callback_connected(self) -> None: self.server.decoder, client=None, ) - self.server.local_active_connections[self.client_address] = self # schedule the connection handler on the event loop self.handler_task = asyncio.create_task(self.handle()) @@ -90,15 +90,15 @@ def callback_disconnected(self, call_exc: Exception) -> None: try: if self.handler_task: self.handler_task.cancel() - if self.client_address in self.server.local_active_connections: - self.server.local_active_connections.pop(self.client_address) if hasattr(self.server, "on_connection_lost"): self.server.on_connection_lost() if call_exc is None: self._log_exception() else: Log.debug( - "Client Disconnection {} due to {}", self.client_address, call_exc + "Client Disconnection {} due to {}", + self.comm_params.comm_name, + call_exc, ) self.running = False except Exception as exc: # pylint: disable=broad-except @@ -167,11 +167,10 @@ async def handle(self): # should handle application layer errors # for UDP sockets, simply reset the frame if isinstance(self, ModbusServerRequestHandler): - client_addr = self.client_address[:2] Log.error( 'Unknown exception "{}" on stream {} forcing disconnect', exc, - client_addr, + self.comm_params.comm_name, ) self.transport_close() else: @@ -309,7 +308,6 @@ def __init__( params, True, ) - self.local_active_connections = {} self.decoder = ServerDecoder() self.framer = framer or ModbusSocketFramer self.context = context or ModbusServerContext() @@ -358,16 +356,6 @@ async def shutdown(self): async def server_close(self): """Close server.""" - active_connecions = self.local_active_connections.copy() - for k_item, v_item in active_connecions.items(): - Log.warning("aborting active session {}", k_item) - if v_item.transport: - v_item.transport.close() - await asyncio.sleep(0.1) - if v_item.handler_task: - v_item.handler_task.cancel() - await v_item.handler_task - self.local_active_connections = {} self.transport_close() @@ -488,7 +476,6 @@ def __init__( True, ) - self.local_active_connections = {} self.loop = asyncio.get_running_loop() self.decoder = ServerDecoder() self.framer = framer or ModbusSocketFramer @@ -602,7 +589,6 @@ def __init__( self.control = ModbusControlBlock() if isinstance(identity, ModbusDeviceIdentification): self.control.Identity.update(identity) - self.local_active_connections = {} self.request_tracer = None self.server = None self.control = ModbusControlBlock() @@ -620,15 +606,6 @@ def handle_new_connection(self): async def shutdown(self): """Terminate server.""" self.transport_close() - loop_list = list(self.local_active_connections) - for k_item in loop_list: - v_item = self.local_active_connections[k_item] - Log.warning("aborting active session {}", k_item) - v_item.transport.close() - await asyncio.sleep(0.1) - v_item.handler_task.cancel() - await v_item.handler_task - self.local_active_connections = {} if self.server: self.server.close() await asyncio.wait_for(self.server.wait_closed(), 10) diff --git a/test/test_server_asyncio.py b/test/test_server_asyncio.py index ec3f573cd..7cbaf823a 100755 --- a/test/test_server_asyncio.py +++ b/test/test_server_asyncio.py @@ -229,11 +229,9 @@ async def test_async_tcp_server_connection_lost(self): """Test tcp stream interruption""" await self.start_server() await self.connect_server() - assert len(self.server.local_active_connections), 1 BasicClient.transport.close() await asyncio.sleep(0.2) # so we have to wait a bit - assert not self.server.local_active_connections async def test_async_tcp_server_close_connection(self): """Test server_close() while there are active TCP connections""" From 21fb9c780120bc7b89ee60ec3c8380a1c23d0b44 Mon Sep 17 00:00:00 2001 From: jan iversen Date: Wed, 5 Jul 2023 20:35:46 +0200 Subject: [PATCH 4/4] last touch. --- pymodbus/server/async_io.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/pymodbus/server/async_io.py b/pymodbus/server/async_io.py index 933492a21..db83c1d81 100644 --- a/pymodbus/server/async_io.py +++ b/pymodbus/server/async_io.py @@ -36,11 +36,9 @@ class ModbusServerRequestHandler(ModbusProtocol): This uses the asyncio.Protocol to implement the server protocol. - When a connection is established, the asyncio.Protocol.connection_made - callback is called. This callback will setup the connection and + When a connection is established, a callback is called. + This callback will setup the connection and create and schedule an asyncio.Task and assign it to running_task. - - running_task will be canceled upon connection_lost event. """ def __init__(self, owner): @@ -80,7 +78,7 @@ def callback_connected(self) -> None: self.handler_task = asyncio.create_task(self.handle()) except Exception as exc: # pragma: no cover pylint: disable=broad-except Log.error( - "Server connection_made unable to fulfill request: {}; {}", + "Server callback_connected exception: {}; {}", exc, traceback.format_exc(), ) @@ -145,8 +143,6 @@ async def handle(self): the ModbusServerRequestHandler class's callback Future. This callback future gets data from either - asyncio.DatagramProtocol.datagram_received or - from asyncio.BaseProtocol.data_received. This function will execute without blocking in the while-loop and yield to the asyncio event loop when the frame is exhausted.