diff --git a/pymodbus/server/async_io.py b/pymodbus/server/async_io.py index 71c2ce71e7..83da04c351 100644 --- a/pymodbus/server/async_io.py +++ b/pymodbus/server/async_io.py @@ -11,6 +11,7 @@ from pymodbus.device import ModbusControlBlock, ModbusDeviceIdentification from pymodbus.exceptions import NoSuchSlaveException from pymodbus.factory import ServerDecoder +from pymodbus.framer import ModbusFramer from pymodbus.logging import Log from pymodbus.pdu import ModbusExceptions as merror from pymodbus.transaction import ( @@ -20,6 +21,7 @@ ModbusTlsFramer, ) from pymodbus.transport.serial_asyncio import create_serial_connection +from pymodbus.transport.transport import CommParams, Transport with suppress(ImportError): @@ -64,7 +66,7 @@ def sslctx_provider( # --------------------------------------------------------------------------- # -class ModbusServerRequestHandler(asyncio.BaseProtocol): +class ModbusServerRequestHandler(Transport): """Implements modbus slave wire protocol. This uses the asyncio.Protocol to implement the server protocol. @@ -78,12 +80,20 @@ class ModbusServerRequestHandler(asyncio.BaseProtocol): def __init__(self, owner): """Initialize.""" + params = CommParams( + comm_name="server", + reconnect_delay=0.0, + reconnect_delay_max=0.0, + timeout_connect=0.0, + ) + super().__init__(params, True) self.server = owner self.running = False self.receive_queue = asyncio.Queue() self.handler_task = None # coroutine to be run on asyncio loop self._sent = b"" # for handle_local_echo self.client_address = (None, None) + self.framer: ModbusFramer = None def _log_exception(self): """Show log exception.""" @@ -108,13 +118,11 @@ def connection_made(self, transport): self.client_address = ("serial", "server") else: Log.warning("Unable to get information about transport {}", transport) - self.transport = transport # pylint: disable=attribute-defined-outside-init + self.transport = transport self.running = True - self.framer = ( # pylint: disable=attribute-defined-outside-init - self.server.framer( - self.server.decoder, - client=None, - ) + self.framer = self.server.framer( + self.server.decoder, + client=None, ) self.server.active_connections[self.client_address] = self @@ -267,13 +275,15 @@ def execute(self, request, *addr): response, skip_encoding = self.server.response_manipulator(response) self.send(response, *addr, skip_encoding=skip_encoding) - def send(self, message, *addr, **kwargs): + def send(self, message, *addr, **kwargs): # pylint: disable=arguments-differ """Send message.""" def __send(msg, *addr): Log.debug("send: [{}]- {}", message, msg, ":b2a") if addr == (None,): - self._send_(msg) + self.transport.write(msg) + if self.server.handle_local_echo is True: + self._sent = msg else: self.transport.sendto(msg, *addr) @@ -286,20 +296,6 @@ def __send(msg, *addr): else: Log.debug("Skipping sending response!!") - # ----------------------------------------------------------------------- # - # Derived class implementations - # ----------------------------------------------------------------------- # - - def _send_(self, data): # pragma: no cover - """Send a request (string) to the network. - - :param data: The unencoded modbus response - :raises NotImplementedException: - """ - self.transport.write(data) - if self.server.handle_local_echo is True: - self._sent = data - async def _recv_(self): # pragma: no cover """Receive data from the network.""" try: @@ -390,8 +386,6 @@ def __init__( self.context = context or ModbusServerContext() self.control = ModbusControlBlock() self.path = path - self.handler = ModbusServerRequestHandler - self.handler.server = self self.ignore_missing_slaves = kwargs.get("ignore_missing_slaves", False) self.broadcast_enable = kwargs.get("broadcast_enable", False) self.response_manipulator = kwargs.get("response_manipulator", None) @@ -412,8 +406,10 @@ async def serve_forever(self): """Start endless loop.""" if self.server is None: try: + handler = ModbusServerRequestHandler + handler.server = self self.server = await self.loop.create_unix_server( - lambda: self.handler(self), + lambda: handler(self), self.path, ) self.serving.set_result(True) @@ -494,8 +490,6 @@ def __init__( self.context = context or ModbusServerContext() self.control = ModbusControlBlock() self.address = address or ("", 502) - self.handler = ModbusServerRequestHandler - self.handler.server = self self.ignore_missing_slaves = kwargs.get("ignore_missing_slaves", False) self.broadcast_enable = kwargs.get("broadcast_enable", False) self.response_manipulator = kwargs.get("response_manipulator", None) @@ -519,8 +513,10 @@ def __init__( async def serve_forever(self): """Start endless loop.""" if self.server is None: + handler = ModbusServerRequestHandler + handler.server = self self.server = await self.loop.create_server( - lambda: self.handler(self), + lambda: handler(self), *self.address, **self.factory_parms, ) @@ -669,7 +665,6 @@ def __init__( self.context = context or ModbusServerContext() self.control = ModbusControlBlock() self.address = address or ("", 502) - self.handler = ModbusServerRequestHandler self.ignore_missing_slaves = kwargs.get("ignore_missing_slaves", False) self.broadcast_enable = kwargs.get("broadcast_enable", False) self.response_manipulator = kwargs.get("response_manipulator", None) @@ -694,8 +689,10 @@ async def serve_forever(self): """Start endless loop.""" if self.protocol is None: try: + handler = ModbusServerRequestHandler + handler.server = self self.protocol, self.endpoint = await self.loop.create_datagram_endpoint( - lambda: self.handler(self), + lambda: handler(self), **self.factory_parms, ) except asyncio.exceptions.CancelledError: @@ -782,7 +779,6 @@ def __init__( self.auto_reconnect = kwargs.get("auto_reconnect", False) self.reconnect_delay = kwargs.get("reconnect_delay", 2) self.reconnecting_task = None - self.handler = kwargs.get("handler") or ModbusServerRequestHandler self.framer = framer or ModbusRtuFramer self.decoder = ServerDecoder() self.context = context or ModbusServerContext() @@ -818,9 +814,11 @@ async def _connect(self): if self.device.startswith("socket:"): return try: + handler = ModbusServerRequestHandler + handler.server = self self.transport, self.protocol = await create_serial_connection( self.loop, - lambda: self.handler(self), + lambda: handler(self), self.device, baudrate=self.baudrate, bytesize=self.bytesize, @@ -889,8 +887,10 @@ async def serve_forever(self): # Socket server means listen so start a socket server parts = self.device[9:].split(":") host_addr = (parts[0], int(parts[1])) + handler = ModbusServerRequestHandler + handler.server = self self.server = await self.loop.create_server( - lambda: self.handler(self), + lambda: handler(self), *host_addr, reuse_address=True, start_serving=True, diff --git a/pymodbus/transport/nullmodem.py b/pymodbus/transport/nullmodem.py deleted file mode 100644 index ea86857362..0000000000 --- a/pymodbus/transport/nullmodem.py +++ /dev/null @@ -1,115 +0,0 @@ -"""Null modem transport. - -This is a special transport, mostly thought of for testing. - -NullModem interconnect 2 transport objects and transfers calls: - - server.listen() - - dummy - - client.connect() - - call client.connection_made() - - call server.connection_made() - - client/server.close() - - call client.connection_lost() - - call server.connection_lost() - - server/client.send - - call client/server.data_received() - -""" -from __future__ import annotations - -import asyncio - -from pymodbus.logging import Log -from pymodbus.transport.transport import Transport - - -class DummyTransport(asyncio.BaseTransport): - """Use in connection_made calls.""" - - def close(self): - """Define dummy.""" - - def get_protocol(self): - """Define dummy.""" - - def is_closing(self): - """Define dummy.""" - - def set_protocol(self, _protocol): - """Define dummy.""" - - def abort(self): - """Define dummy.""" - - -class NullModem(Transport): - """Transport layer. - - Contains methods to act as a null modem between 2 objects. - (Allowing tests to be shortcut without actual network calls) - """ - - nullmodem_client: NullModem = None - nullmodem_server: NullModem = None - - def __init__(self, *arg): - """Overwrite init.""" - self.is_server: bool = False - self.other_end: NullModem = None - super().__init__(*arg) - - async def transport_connect(self) -> bool: - """Handle generic connect and call on to specific transport connect.""" - Log.debug("NullModem: Simulate connect on {}", self.comm_params.comm_name) - if not self.loop: - self.loop = asyncio.get_running_loop() - if self.nullmodem_server: - self.__class__.nullmodem_client = self - self.other_end = self.nullmodem_server - self.nullmodem_server.other_end = self - self.cb_connection_made() - self.other_end.cb_connection_made() - return True - return False - - async def transport_listen(self): - """Handle generic listen and call on to specific transport listen.""" - Log.debug("NullModem: Simulate listen on {}", self.comm_params.comm_name) - if not self.loop: - self.loop = asyncio.get_running_loop() - self.is_server = True - self.__class__.nullmodem_server = self - return DummyTransport() - - # -------------------------------- # - # Helper methods for child classes # - # -------------------------------- # - async def send(self, data: bytes) -> bool: - """Send request. - - :param data: non-empty bytes object with data to send. - """ - Log.debug("NullModem: simulate send {}", data, ":hex") - self.other_end.data_received(data) - return True - - def close(self, reconnect: bool = False) -> None: - """Close connection. - - :param reconnect: (default false), try to reconnect - """ - self.recv_buffer = b"" - if not reconnect: - if self.nullmodem_client: - self.nullmodem_client.cb_connection_lost(None) - if self.nullmodem_server: - self.nullmodem_server.cb_connection_lost(None) - self.__class__.nullmodem_client = None - self.__class__.nullmodem_server = None - - # ----------------- # - # The magic methods # - # ----------------- # - def __str__(self) -> str: - """Build a string representation of the connection.""" - return f"{self.__class__.__name__}({self.comm_params.comm_name})" diff --git a/pymodbus/transport/transport.py b/pymodbus/transport/transport.py index 89751fd06f..78ed6432d8 100644 --- a/pymodbus/transport/transport.py +++ b/pymodbus/transport/transport.py @@ -4,242 +4,125 @@ from __future__ import annotations import asyncio +import dataclasses import ssl -import sys -from dataclasses import dataclass +from enum import Enum from typing import Any, Callable, Coroutine from pymodbus.logging import Log from pymodbus.transport.serial_asyncio import create_serial_connection -class Transport: - """Transport layer. +NULLMODEM_HOST = "__pymodbus_nullmodem" - Contains pure transport methods needed to connect/listen, send/receive and close connections - for unix socket, tcp, tls and serial communications. - Contains high level methods like reconnect. +class CommType(Enum): + """Type of transport""" - This class is not available in the pymodbus API, and should not be referenced in Applications - nor in the pymodbus documentation. + TCP = 1 + TLS = 2 + UDP = 3 + SERIAL = 4 - The class is designed to be an object in the message level class. - """ - @dataclass - class CommParamsClass: - """Parameter class.""" +@dataclasses.dataclass +class CommParams: + """Parameter class.""" + + # generic + comm_name: str = None + comm_type: CommType = None + reconnect_delay: float = None + reconnect_delay_max: float = None + timeout_connect: float = None + + # tcp / tls / udp / serial + host: str = None - # generic - done: bool = False - comm_name: str = None - reconnect_delay: float = None - reconnect_delay_max: float = None - timeout_connect: float = None + # tcp / tls / udp + port: int = None - # tcp / tls / udp / serial - host: str = None + # tls + sslctx: ssl.SSLContext = None - # tcp / tls / udp - port: int = None + # serial + baudrate: int = None + bytesize: int = None + parity: str = None + stopbits: int = None - # tls - ssl: ssl.SSLContext = None - server_hostname: str = None + def generate_ssl( + self, + is_server: bool, + certfile: str = None, + keyfile: str = None, + password: str = None, + ) -> None: + """Generate sslctx from cert/key/passwor + + MODBUS/TCP Security Protocol Specification demands TLSv2 at least + """ + self.sslctx = ssl.SSLContext( + ssl.PROTOCOL_TLS_SERVER if is_server else ssl.PROTOCOL_TLS_CLIENT + ) + self.sslctx.check_hostname = False + self.sslctx.verify_mode = ssl.CERT_NONE + self.sslctx.options |= ssl.OP_NO_TLSv1_1 + self.sslctx.options |= ssl.OP_NO_TLSv1 + self.sslctx.options |= ssl.OP_NO_SSLv3 + self.sslctx.options |= ssl.OP_NO_SSLv2 + if certfile: + self.sslctx.load_cert_chain( + certfile=certfile, keyfile=keyfile, password=password + ) - # serial - baudrate: int = None - bytesize: int = None - parity: str = None - stopbits: int = None - def check_done(self): - """Check if already setup""" - if self.done: - raise RuntimeError("Already setup!") - self.done = True +class Transport(asyncio.BaseProtocol): + """Protocol layer including transport. + + Contains pure transport methods needed to connect/listen, send/receive and close connections + for unix socket, tcp, tls and serial communications. + + Contains high level methods like reconnect. + + The class is designed to take care of differences between the different transport mediums, and + provide a neutral interface for the upper layers. + """ def __init__( self, - comm_name: str, - reconnect_delay: int, - reconnect_max: int, - timeout_connect: int, - callback_connected: Callable[[], None], - callback_disconnected: Callable[[Exception], None], - callback_data: Callable[[bytes], int], + params: CommParams, + is_server: bool, ) -> None: """Initialize a transport instance. - :param comm_name: name of this transport connection - :param reconnect_delay: delay in milliseconds for first reconnect (0 for no reconnect) - :param reconnect_delay: max reconnect delay in milliseconds - :param timeout_connect: Max. time in milliseconds for connect to complete + :param params: parameter dataclass + :param is_server: true if object act as a server (listen allowed) :param callback_connected: Called when connection is established :param callback_disconnected: Called when connection is disconnected :param callback_data: Called when data is received """ - self.cb_connection_made = callback_connected - self.cb_connection_lost = callback_disconnected - self.cb_handle_data = callback_data - - # properties, can be read, but may not be mingled with - self.comm_params = self.CommParamsClass( - comm_name=comm_name, - reconnect_delay=reconnect_delay / 1000, - reconnect_delay_max=reconnect_max / 1000, - timeout_connect=timeout_connect / 1000, - ) + self.comm_params = dataclasses.replace(params) + self.is_server = is_server self.reconnect_delay_current: float = 0.0 + self.listener: Transport = None self.transport: asyncio.BaseTransport | asyncio.Server = None self.loop: asyncio.AbstractEventLoop = None self.reconnect_task: asyncio.Task = None self.recv_buffer: bytes = b"" - self.call_connect_listen: Callable[[], Coroutine[Any, Any, Any]] = lambda: None - self.use_udp = False - - # ------------------------ # - # Transport specific setup # - # ------------------------ # - def setup_unix(self, setup_server: bool, host: str): - """Prepare transport unix""" - if sys.platform.startswith("win"): - raise RuntimeError("Modbus_unix is not supported on Windows!") - self.comm_params.check_done() - self.comm_params.done = True - self.comm_params.host = host - if setup_server: - self.call_connect_listen = lambda: self.loop.create_unix_server( - self.handle_listen, - path=self.comm_params.host, - start_serving=True, - ) - else: - self.call_connect_listen = lambda: self.loop.create_unix_connection( - lambda: self, - path=self.comm_params.host, - ) - - def setup_tcp(self, setup_server: bool, host: str, port: int): - """Prepare transport tcp""" - self.comm_params.check_done() - self.comm_params.done = True - self.comm_params.host = host - self.comm_params.port = port - if setup_server: - self.call_connect_listen = lambda: self.loop.create_server( - self.handle_listen, - host=self.comm_params.host, - port=self.comm_params.port, - reuse_address=True, - start_serving=True, - ) - else: - self.call_connect_listen = lambda: self.loop.create_connection( - lambda: self, - host=self.comm_params.host, - port=self.comm_params.port, - ) - - def setup_tls( - self, - setup_server: bool, - host: str, - port: int, - sslctx: ssl.SSLContext, - certfile: str, - keyfile: str, - password: str, - server_hostname: str, - ): - """Prepare transport tls""" - self.comm_params.check_done() - self.comm_params.done = True - self.comm_params.host = host - self.comm_params.port = port - self.comm_params.server_hostname = server_hostname - if not sslctx: - # According to MODBUS/TCP Security Protocol Specification, it is - # TLSv2 at least - sslctx = ssl.SSLContext( - ssl.PROTOCOL_TLS_SERVER if setup_server else ssl.PROTOCOL_TLS_CLIENT - ) - sslctx.check_hostname = False - sslctx.verify_mode = ssl.CERT_NONE - sslctx.options |= ssl.OP_NO_TLSv1_1 - sslctx.options |= ssl.OP_NO_TLSv1 - sslctx.options |= ssl.OP_NO_SSLv3 - sslctx.options |= ssl.OP_NO_SSLv2 - if certfile: - sslctx.load_cert_chain( - certfile=certfile, keyfile=keyfile, password=password - ) - self.comm_params.ssl = sslctx - if setup_server: - self.call_connect_listen = lambda: self.loop.create_server( - self.handle_listen, - host=self.comm_params.host, - port=self.comm_params.port, - reuse_address=True, - ssl=self.comm_params.ssl, - start_serving=True, - ) - else: - self.call_connect_listen = lambda: self.loop.create_connection( - lambda: self, - self.comm_params.host, - self.comm_params.port, - ssl=self.comm_params.ssl, - server_hostname=self.comm_params.server_hostname, - ) - - def setup_udp(self, setup_server: bool, host: str, port: int): - """Prepare transport udp""" - self.comm_params.check_done() - self.comm_params.done = True - self.comm_params.host = host - self.comm_params.port = port - if setup_server: - - async def call_async_listen(self): - """Remove protocol return value.""" - transport, _protocol = await self.loop.create_datagram_endpoint( - self.handle_listen, - local_addr=(self.comm_params.host, self.comm_params.port), - ) - return transport - - self.call_connect_listen = lambda: call_async_listen(self) - else: - self.call_connect_listen = lambda: self.loop.create_datagram_endpoint( - lambda: self, - remote_addr=(self.comm_params.host, self.comm_params.port), - ) - self.use_udp = True - - def setup_serial( - self, - setup_server: bool, - host: str, - baudrate: int, - bytesize: int, - parity: str, - stopbits: int, - ): - """Prepare transport serial""" - self.comm_params.check_done() - self.comm_params.done = True - self.comm_params.host = host - self.comm_params.baudrate = baudrate - self.comm_params.bytesize = bytesize - self.comm_params.parity = parity - self.comm_params.stopbits = stopbits - if setup_server: - self.call_connect_listen = lambda: create_serial_connection( + self.call_create: Callable[[], Coroutine[Any, Any, Any]] = lambda: None + self.active_connections: dict[str, Transport] = {} + self.unique_id: str = str(id(self)) + + # Transport specific setup + if params.host == NULLMODEM_HOST: + self.call_create = self.create_nullmodem + return + if params.comm_type == CommType.SERIAL: + self.call_create = lambda: create_serial_connection( self.loop, - self.handle_listen, + self.handle_new_connection, self.comm_params.host, baudrate=self.comm_params.baudrate, bytesize=self.comm_params.bytesize, @@ -247,17 +130,35 @@ def setup_serial( stopbits=self.comm_params.stopbits, timeout=self.comm_params.timeout_connect, ) - + return + if params.comm_type == CommType.UDP: + if is_server: + self.call_create = lambda: self.loop.create_datagram_endpoint( + self.handle_new_connection, + local_addr=(self.comm_params.host, self.comm_params.port), + ) + else: + self.call_create = lambda: self.loop.create_datagram_endpoint( + self.handle_new_connection, + remote_addr=(self.comm_params.host, self.comm_params.port), + ) + return + # TLS and TCP + if is_server: + self.call_create = lambda: self.loop.create_server( + self.handle_new_connection, + self.comm_params.host, + self.comm_params.port, + ssl=self.comm_params.sslctx, + reuse_address=True, + start_serving=True, + ) else: - self.call_connect_listen = lambda: create_serial_connection( - self.loop, - lambda: self, + self.call_create = lambda: self.loop.create_connection( + self.handle_new_connection, self.comm_params.host, - baudrate=self.comm_params.baudrate, - bytesize=self.comm_params.bytesize, - stopbits=self.comm_params.stopbits, - parity=self.comm_params.parity, - timeout=self.comm_params.timeout_connect, + self.comm_params.port, + ssl=self.comm_params.sslctx, ) async def transport_connect(self) -> bool: @@ -265,10 +166,9 @@ async def transport_connect(self) -> bool: Log.debug("Connecting {}", self.comm_params.comm_name) if not self.loop: self.loop = asyncio.get_running_loop() - self.transport = None try: self.transport, _protocol = await asyncio.wait_for( - self.call_connect_listen(), + self.call_create(), timeout=self.comm_params.timeout_connect, ) except ( @@ -280,15 +180,18 @@ async def transport_connect(self) -> bool: return False return bool(self.transport) - async def transport_listen(self): + async def transport_listen(self) -> bool: """Handle generic listen and call on to specific transport listen.""" Log.debug("Awaiting connections {}", self.comm_params.comm_name) try: - self.transport = await self.call_connect_listen() + self.transport = await self.call_create() + if isinstance(self.transport, tuple): + self.transport = self.transport[0] except OSError as exc: Log.warning("Failed to start server {}", exc) self.close() - return self.transport + return False + return True # ---------------------------------- # # Transport asyncio standard methods # @@ -299,22 +202,22 @@ def connection_made(self, transport: asyncio.BaseTransport): :param transport: socket etc. representing the connection. """ Log.debug("Connected to {}", self.comm_params.comm_name) - if not self.loop: - self.loop = asyncio.get_running_loop() self.transport = transport self.reset_delay() - self.cb_connection_made() + self.callback_connected() def connection_lost(self, reason: Exception): """Call from asyncio, when the connection is lost or closed. :param reason: None or an exception object """ + if not self.transport: + return Log.debug("Connection lost {} due to {}", self.comm_params.comm_name, reason) - self.cb_connection_lost(reason) - if self.transport: - self.close() - self.reconnect_task = asyncio.create_task(self.reconnect_connect()) + self.close() + if not self.is_server: + self.reconnect_task = asyncio.create_task(self.do_reconnect()) + self.callback_disconnected(reason) def data_received(self, data: bytes): """Call when some data is received. @@ -323,7 +226,7 @@ def data_received(self, data: bytes): """ Log.debug("recv: {}", data, ":hex") self.recv_buffer += data - cut = self.cb_handle_data(self.recv_buffer) + cut = self.callback_data(self.recv_buffer) self.recv_buffer = self.recv_buffer[cut:] def datagram_received(self, data, _addr): @@ -331,32 +234,44 @@ def datagram_received(self, data, _addr): self.data_received(data) def eof_received(self): - """Accept other end terminates connection. - - Actual handling are in connection_lost() - """ + """Accept other end terminates connection.""" Log.debug("-> eof_received") def error_received(self, exc): - """Get error detected in UDP. - - Actual handling are in connection_lost() - """ + """Get error detected in UDP.""" Log.debug("-> error_received {}", exc) raise RuntimeError(str(exc)) + # --------- # + # callbacks # + # --------- # + def callback_connected(self) -> None: + """Call when connection is succcesfull.""" + + def callback_disconnected(self, _exc: Exception) -> None: + """Call when connection is lost.""" + + def callback_data(self, _data: bytes) -> int: + """Handle received data.""" + return 0 + # ----------------------------------- # # Helper methods for external classes # # ----------------------------------- # - async def send(self, data: bytes) -> bool: + def send(self, data: bytes, addr: tuple = None) -> None: """Send request. :param data: non-empty bytes object with data to send. + :param addr: optional addr, only used for UDP server. """ Log.debug("send: {}", data, ":hex") - if self.use_udp: - return self.transport.sendto(data) # type: ignore[union-attr] - return self.transport.write(data) # type: ignore[union-attr] + if self.comm_params.comm_type == CommType.UDP: + if addr: + self.transport.sendto(data, addr=addr) # type: ignore[union-attr] + else: + self.transport.sendto(data) # type: ignore[union-attr] + else: + self.transport.write(data) # type: ignore[union-attr] def close(self, reconnect: bool = False) -> None: """Close connection. @@ -371,7 +286,15 @@ def close(self, reconnect: bool = False) -> None: if not reconnect and self.reconnect_task: self.reconnect_task.cancel() self.reconnect_task = None - self.recv_buffer = b"" + self.reconnect_delay_current = 0.0 + self.recv_buffer = b"" + if self.listener: + self.listener.active_connections.pop(self.unique_id) + elif self.is_server: + for _key, value in self.active_connections.items(): + value.listener = None + value.close() + self.active_connections = {} def reset_delay(self) -> None: """Reset wait time before next reconnect to minimal period.""" @@ -384,11 +307,26 @@ def is_active(self) -> bool: # ---------------- # # Internal methods # # ---------------- # - def handle_listen(self): + async def create_nullmodem(self): + """Bypass create_ and use null modem""" + new_transport = NullModem(self.is_server, self) + new_protocol = self.handle_new_connection() + new_protocol.connection_made(new_transport) + if self.is_server: + return new_transport, new_protocol + return new_transport, new_protocol + + def handle_new_connection(self): """Handle incoming connect.""" - return self + if not self.is_server: + return self - async def reconnect_connect(self): + new_transport = Transport(self.comm_params, True) + new_transport.listener = self + self.active_connections[new_transport.unique_id] = new_transport + return new_transport + + async def do_reconnect(self): """Handle reconnect as a task.""" try: self.reconnect_delay_current = self.comm_params.reconnect_delay @@ -423,3 +361,82 @@ async def __aexit__(self, _class, _value, _traceback) -> None: def __str__(self) -> str: """Build a string representation of the connection.""" return f"{self.__class__.__name__}({self.comm_params.comm_name})" + + +class NullModem(asyncio.DatagramTransport, asyncio.WriteTransport): + """Transport layer. + + Contains methods to act as a null modem between 2 objects. + (Allowing tests to be shortcut without actual network calls) + """ + + client: NullModem = None + server: NullModem = None + client_protocol: Transport = None + server_protocol: Transport = None + + def __init__(self, is_server: bool, protocol: Transport): + """Create half part of null modem""" + asyncio.DatagramTransport.__init__(self) + asyncio.WriteTransport.__init__(self) + self.other_protocol: Transport = None + if is_server: + self.__class__.server = self + self.__class__.server_protocol = protocol + return + if not self.__class__.server: + raise RuntimeError("Connect called before listen") + self.__class__.client = self + self.__class__.client_protocol = protocol + self.client.other_protocol = self.server_protocol + self.server.other_protocol = self.client_protocol + + # ---------------- # + # external methods # + # ---------------- # + + def close(self): + """Close null modem""" + + def sendto(self, data: bytes, _addr: Any = None): + """Send datagrame""" + return self.write(data) + + def write(self, data: bytes): + """Send data""" + return len(data) + + # ---------------- # + # Abstract methods # + # ---------------- # + def abort(self) -> None: + """Abort connection.""" + + def can_write_eof(self) -> bool: + """Allow to write eof""" + return True + + def get_write_buffer_size(self) -> int: + """Set write limit.""" + return 1024 + + def get_write_buffer_limits(self) -> tuple[int, int]: + """Set flush limits""" + return (1, 1024) + + def set_write_buffer_limits(self, high: int = None, low: int = None) -> None: + """Set flush limits""" + + def write_eof(self) -> None: + """Write eof""" + + def get_protocol(self) -> Transport: + """Return current protocol.""" + return None + + def set_protocol(self, protocol: asyncio.BaseProtocol) -> None: + """Set current protocol.""" + + def is_closing(self) -> bool: + """Return true if closing""" + return False diff --git a/test/sub_transport/conftest.py b/test/sub_transport/conftest.py index 29bfdbdc74..1f8052c66e 100644 --- a/test/sub_transport/conftest.py +++ b/test/sub_transport/conftest.py @@ -1,139 +1,126 @@ """Fixtures for transport tests.""" import asyncio +import dataclasses import os -import time from contextlib import suppress -from dataclasses import dataclass -from tempfile import gettempdir from unittest import mock import pytest -import pytest_asyncio -from pymodbus.transport.nullmodem import NullModem -from pymodbus.transport.transport import Transport +from pymodbus.transport.transport import CommParams, CommType, NullModem, Transport -@dataclass -class BaseParams(Transport.CommParamsClass): - """Base parameters for all transport testing.""" +class DummyTransport(asyncio.BaseTransport): + """Use in connection_made calls.""" - comm_name = "test comm" - reconnect_delay = 1000 - reconnect_delay_max = 3500 - timeout_connect = 2000 - host = "test host" - port = 502 - server_hostname = "server test host" - baudrate = 9600 - bytesize = 8 - parity = "e" - stopbits = 2 - cwd = os.path.dirname(__file__) + "/../../examples/certificates/pymodbus." + def close(self): + """Define dummy.""" + def get_protocol(self): + """Define dummy.""" -@pytest.fixture(name="params") -def prepare_baseparams(use_port): - """Prepare BaseParams class.""" - BaseParams.port = use_port - return BaseParams + def is_closing(self): + """Define dummy.""" + def set_protocol(self, _protocol): + """Define dummy.""" -@pytest.fixture(name="commparams") -def prepare_testparams(): - """Prepare CommParamsClass object.""" - return Transport.CommParamsClass( - done=True, - comm_name=BaseParams.comm_name, - reconnect_delay=BaseParams.reconnect_delay / 1000, - reconnect_delay_max=BaseParams.reconnect_delay_max / 1000, - timeout_connect=BaseParams.timeout_connect / 1000, - ) + def abort(self): + """Define dummy.""" -@pytest.fixture(name="transport") -async def prepare_transport(): - """Prepare transport object.""" - transport = Transport( - BaseParams.comm_name, - BaseParams.reconnect_delay, - BaseParams.reconnect_delay_max, - BaseParams.timeout_connect, - mock.Mock(name="cb_connection_made"), - mock.Mock(name="cb_connection_lost"), - mock.Mock(name="cb_handle_data", return_value=0), - ) - with suppress(RuntimeError): - transport.loop = asyncio.get_running_loop() - return transport +@pytest.fixture(name="dummy_transport") +def prepare_dummy_transport(): + """Return transport object""" + return DummyTransport() -@pytest.fixture(name="nullmodem") -async def prepare_nullmodem(): - """Prepare nullmodem object.""" - transport = NullModem( - BaseParams.comm_name, - BaseParams.reconnect_delay, - BaseParams.reconnect_delay_max, - BaseParams.timeout_connect, - mock.Mock(name="cb_connection_made"), - mock.Mock(name="cb_connection_lost"), - mock.Mock(name="cb_handle_data", return_value=0), - ) - transport.__class__.nullmodem_client = None - transport.__class__.nullmodem_server = None - with suppress(RuntimeError): - transport.loop = asyncio.get_running_loop() - return transport +@pytest.fixture(name="cwd_certificate") +def prepare_cwd_certificate(): + """Prepare path to certificate.""" + return os.path.dirname(__file__) + "/../../examples/certificates/pymodbus." -@pytest.fixture(name="nullmodem_server") -async def prepare_nullmodem_server(): - """Prepare nullmodem object.""" - transport = NullModem( - BaseParams.comm_name, - BaseParams.reconnect_delay, - BaseParams.reconnect_delay_max, - BaseParams.timeout_connect, - mock.Mock(name="cb_connection_made"), - mock.Mock(name="cb_connection_lost"), - mock.Mock(name="cb_handle_data", return_value=0), +@pytest.fixture(name="use_comm_type") +def prepare_dummy_use_comm_type(): + """Return default comm_type""" + return CommType.TCP + + +@pytest.fixture(name="use_host") +def prepare_dummy_use_host(): + """Return default host""" + return "localhost" + + +@pytest.fixture(name="commparams") +def prepare_commparams(use_port, use_host, use_comm_type): + """Prepare CommParamsClass object.""" + return CommParams( + comm_name="test comm", + comm_type=use_comm_type, + reconnect_delay=1, + reconnect_delay_max=3.5, + timeout_connect=2, + host=use_host, + port=use_port, + baudrate=9600, + bytesize=8, + parity="E", + stopbits=2, ) - transport.__class__.nullmodem_client = None - transport.__class__.nullmodem_server = None + + +@pytest.fixture(name="client") +async def prepare_transport(commparams): + """Prepare transport object.""" + transport = Transport(commparams, False) with suppress(RuntimeError): transport.loop = asyncio.get_running_loop() + transport.callback_connected = mock.Mock() + transport.callback_disconnected = mock.Mock() + transport.callback_data = mock.Mock(return_value=0) + if commparams.comm_type == CommType.TLS: + cwd = os.path.dirname(__file__) + "/../../examples/certificates/pymodbus." + transport.comm_params.generate_ssl( + False, certfile=cwd + "crt", keyfile=cwd + "key" + ) + if commparams.comm_type == CommType.SERIAL: + transport.comm_params.host = f"socket://localhost:{transport.comm_params.port}" return transport -@pytest_asyncio.fixture(name="transport_server") -async def prepare_transport_server(): +@pytest.fixture(name="server") +async def prepare_transport_server(commparams): """Prepare transport object.""" - transport = Transport( - BaseParams.comm_name, - BaseParams.reconnect_delay, - BaseParams.reconnect_delay_max, - BaseParams.timeout_connect, - mock.Mock(name="cb_connection_made"), - mock.Mock(name="cb_connection_lost"), - mock.Mock(name="cb_handle_data", return_value=0), - ) + if commparams.comm_type == CommType.SERIAL: + commparams = dataclasses.replace(commparams) + commparams.comm_type = CommType.TCP + transport = Transport(commparams, True) with suppress(RuntimeError): transport.loop = asyncio.get_running_loop() + transport.callback_connected = mock.Mock() + transport.callback_disconnected = mock.Mock() + transport.callback_data = mock.Mock(return_value=0) + if commparams.comm_type == CommType.TLS: + cwd = os.path.dirname(__file__) + "/../../examples/certificates/pymodbus." + transport.comm_params.generate_ssl( + True, certfile=cwd + "crt", keyfile=cwd + "key" + ) + elif commparams.comm_type == CommType.SERIAL: + serial_params = dataclasses.replace(commparams) + serial_params.comm_type = CommType.TCP + transport = Transport(serial_params, True) return transport -@pytest.fixture(name="domain_host") -def get_domain_host(positive): - """Get test host.""" - return "localhost" if positive else "/illegal_host_name" +@pytest.fixture(name="nullmodem") +def prepare_nullmodem(): + """Prepare nullmodem object.""" + return NullModem(False, mock.Mock()) -@pytest.fixture(name="domain_socket") -def get_domain_socket(positive): - """Get test file.""" - return ( - gettempdir() + "/test_unix_" + str(time.time()) - if positive - else "/illegal_file_name" - ) +@pytest.fixture(name="nullmodem_server") +def prepare_nullmodem_server(): + """Prepare nullmodem object.""" + return NullModem(True, mock.Mock()) diff --git a/test/sub_transport/test_basic.py b/test/sub_transport/test_basic.py index 3535b954e0..9fe0d67954 100644 --- a/test/sub_transport/test_basic.py +++ b/test/sub_transport/test_basic.py @@ -3,458 +3,215 @@ from unittest import mock import pytest -from serial import SerialException -from pymodbus.transport.nullmodem import DummyTransport +from pymodbus.transport.transport import NULLMODEM_HOST, CommType, NullModem, Transport -class TestBasicTransport: - """Test transport module, base part.""" +COMM_TYPES = [ + CommType.TCP, + CommType.TLS, + CommType.UDP, + CommType.SERIAL, +] - async def test_init(self, transport, commparams): - """Test init()""" - commparams.done = False - assert transport.comm_params == commparams - assert ( - transport.cb_connection_made._extract_mock_name() # pylint: disable=protected-access - == "cb_connection_made" - ) - assert ( - transport.cb_connection_lost._extract_mock_name() # pylint: disable=protected-access - == "cb_connection_lost" - ) - assert ( - transport.cb_handle_data._extract_mock_name() # pylint: disable=protected-access - == "cb_handle_data" - ) - assert not transport.reconnect_delay_current - assert not transport.reconnect_task - - async def test_property_done(self, transport): - """Test done property""" - transport.comm_params.check_done() - with pytest.raises(RuntimeError): - transport.comm_params.check_done() - async def test_with_magic(self, transport): - """Test magic.""" - transport.close = mock.MagicMock() - async with transport: - pass - transport.close.assert_called_once() +class TestBasicTransport: + """Test transport module.""" - async def test_str_magic(self, params, transport): - """Test magic.""" - assert str(transport) == f"Transport({params.comm_name})" + @pytest.mark.parametrize("use_comm_type", COMM_TYPES) + async def test_init(self, client, server, commparams): + """Test init()""" + if commparams.comm_type == CommType.SERIAL: + client.comm_params.host = commparams.host + server.comm_params.comm_type = commparams.comm_type + client.comm_params.sslctx = None + assert client.comm_params == commparams + assert client.unique_id == str(id(client)) + assert not client.is_server + server.comm_params.sslctx = None + assert server.comm_params == commparams + assert server.unique_id == str(id(server)) + assert server.is_server + + commparams.host = NULLMODEM_HOST + Transport(commparams, False) + + async def test_connect(self, client, dummy_transport): + """Test properties.""" + client.loop = None + client.call_create = mock.AsyncMock(return_value=(dummy_transport, None)) + assert await client.transport_connect() + assert client.loop + client.call_create.side_effect = asyncio.TimeoutError("test") + assert not await client.transport_connect() + + async def test_listen(self, server, dummy_transport): + """Test listen_tcp().""" + server.call_create = mock.AsyncMock(return_value=(dummy_transport, None)) + assert await server.transport_listen() + server.call_create.side_effect = OSError("testing") + assert not await server.transport_listen() - async def test_connection_made(self, transport, commparams): + async def test_connection_made(self, client, commparams, dummy_transport): """Test connection_made().""" - transport.loop = None - transport.connection_made(DummyTransport()) - assert transport.transport - assert not transport.recv_buffer - assert not transport.reconnect_task - assert transport.reconnect_delay_current == commparams.reconnect_delay - transport.cb_connection_made.assert_called_once() - transport.cb_connection_lost.assert_not_called() - transport.cb_handle_data.assert_not_called() - transport.close() - - async def test_connection_lost(self, transport): + client.connection_made(dummy_transport) + assert client.transport + assert not client.recv_buffer + assert not client.reconnect_task + assert client.reconnect_delay_current == commparams.reconnect_delay + client.callback_connected.assert_called_once() + + async def test_connection_lost(self, client, dummy_transport): """Test connection_lost().""" - transport.connection_lost(RuntimeError("not implemented")) - assert not transport.transport - assert not transport.recv_buffer - assert not transport.reconnect_task - assert not transport.reconnect_delay_current - transport.cb_connection_made.assert_not_called() - transport.cb_handle_data.assert_not_called() - transport.cb_connection_lost.assert_called_once() - - transport.transport = mock.Mock() - transport.connection_lost(RuntimeError("not implemented")) - assert not transport.transport - assert transport.reconnect_task - transport.close() - assert not transport.reconnect_task - - async def test_close(self, transport): - """Test close().""" - socket = DummyTransport() - socket.abort = mock.Mock() - socket.close = mock.Mock() - transport.connection_made(socket) - transport.cb_connection_made.reset_mock() - transport.cb_connection_lost.reset_mock() - transport.cb_handle_data.reset_mock() - transport.recv_buffer = b"abc" - transport.reconnect_task = mock.MagicMock() - transport.close() - socket.abort.assert_called_once() - socket.close.assert_called_once() - transport.cb_connection_made.assert_not_called() - transport.cb_connection_lost.assert_not_called() - transport.cb_handle_data.assert_not_called() - assert not transport.recv_buffer - assert not transport.reconnect_task - - async def test_reset_delay(self, transport, commparams): - """Test reset_delay().""" - transport.reconnect_delay_current += 5.17 - transport.reset_delay() - assert transport.reconnect_delay_current == commparams.reconnect_delay - - async def test_datagram(self, transport): + client.connection_lost(RuntimeError("not implemented")) + client.connection_made(dummy_transport) + client.connection_lost(RuntimeError("not implemented")) + assert not client.transport + assert not client.recv_buffer + assert client.reconnect_task + client.callback_disconnected.assert_called_once() + client.close() + assert not client.reconnect_task + assert not client.reconnect_delay_current + + async def test_data_received(self, client): + """Test data_received.""" + client.callback_data = mock.MagicMock(return_value=2) + client.data_received(b"123456") + client.callback_data.assert_called_once() + assert client.recv_buffer == b"3456" + client.data_received(b"789") + assert client.recv_buffer == b"56789" + + async def test_datagram(self, client): """Test datagram_received().""" - transport.data_received = mock.MagicMock() - transport.datagram_received(b"abc", "127.0.0.1") - transport.data_received.assert_called_once() + client.data_received = mock.MagicMock() + client.datagram_received(b"abc", "127.0.0.1") + client.data_received.assert_called_once() - async def test_data(self, transport): - """Test data_received.""" - transport.cb_handle_data = mock.MagicMock(return_value=2) - transport.data_received(b"123456") - transport.cb_handle_data.assert_called_once() - assert transport.recv_buffer == b"3456" - transport.data_received(b"789") - assert transport.recv_buffer == b"56789" - - async def test_eof_received(self, transport): + async def test_eof_received(self, client): """Test eof_received.""" - transport.eof_received() + client.eof_received() - async def test_error_received(self, transport): + async def test_error_received(self, client): """Test error_received.""" with pytest.raises(RuntimeError): - transport.error_received(Exception("test call")) + client.error_received(Exception("test call")) - async def test_send(self, transport, params): + async def test_callbacks(self, commparams): + """Test callbacks.""" + client = Transport(commparams, False) + client.callback_connected() + client.callback_disconnected(Exception("test")) + client.callback_data(b"abcd") + + async def test_send(self, client): """Test send().""" - transport.transport = mock.AsyncMock() - await transport.send(b"abc") + client.transport = mock.AsyncMock() + client.send(b"abc") + + client.comm_params.comm_type = CommType.UDP + client.send(b"abc") + client.send(b"abc", addr=("localhost", 502)) - transport.setup_udp(False, params.host, params.port) - await transport.send(b"abc") - transport.close() + async def test_close(self, server, dummy_transport): + """Test close().""" + dummy_transport.abort = mock.Mock() + dummy_transport.close = mock.Mock() + server.connection_made(dummy_transport) + server.recv_buffer = b"abc" + server.reconnect_task = mock.MagicMock() + server.listener = mock.MagicMock() + server.close() + dummy_transport.abort.assert_called_once() + dummy_transport.close.assert_called_once() + assert not server.recv_buffer + assert not server.reconnect_task + server.listener = None + server.active_connections = {"a": dummy_transport} + server.close() + assert not server.active_connections + + async def test_reset_delay(self, client, commparams): + """Test reset_delay().""" + client.reconnect_delay_current += 5.17 + client.reset_delay() + assert client.reconnect_delay_current == commparams.reconnect_delay + + async def test_is_active(self, client): + """Test is_active().""" + assert not client.is_active() + client.connection_made(mock.AsyncMock()) + assert client.is_active() + + @pytest.mark.parametrize("use_host", [NULLMODEM_HOST]) + async def test_create_nullmodem(self, client, server): + """Test create_nullmodem.""" + await server.transport_listen() + await client.transport_listen() + + async def test_handle_new_connection(self, client, server): + """Test handle_new_connection().""" + server.handle_new_connection() + client.handle_new_connection() + + async def test_do_reconnect(self, client): + """Test do_reconnect().""" + client.comm_params.reconnect_delay = 0.01 + client.transport_connect = mock.AsyncMock(side_effect=[False, True]) + await client.do_reconnect() + assert client.reconnect_delay_current == client.comm_params.reconnect_delay * 2 + assert not client.reconnect_task + client.transport_connect.side_effect = asyncio.CancelledError("stop loop") + await client.do_reconnect() + assert client.reconnect_delay_current == client.comm_params.reconnect_delay + assert not client.reconnect_task + + async def test_with_magic(self, client): + """Test magic.""" + client.close = mock.MagicMock() + async with client: + pass + client.close.assert_called_once() - async def test_handle_listen(self, transport): - """Test handle_listen().""" - assert transport == transport.handle_listen() + async def test_str_magic(self, commparams, client): + """Test magic.""" + assert str(client) == f"Transport({commparams.comm_name})" - async def test_no_loop(self, transport): - """Test properties.""" - transport.loop = None - transport.call_connect_listen = mock.AsyncMock(return_value=(117, 118)) - await transport.transport_connect() - assert transport.loop - - async def test_reconnect_connect(self, transport): - """Test handle_listen().""" - transport.comm_params.reconnect_delay = 0.01 - transport.transport_connect = mock.AsyncMock(side_effect=[False, True]) - await transport.reconnect_connect() - assert ( - transport.reconnect_delay_current - == transport.comm_params.reconnect_delay * 2 - ) - assert not transport.reconnect_task - transport.transport_connect = mock.AsyncMock( - side_effect=asyncio.CancelledError("stop loop") - ) - await transport.reconnect_connect() - assert ( - transport.reconnect_delay_current == transport.comm_params.reconnect_delay - ) - assert not transport.reconnect_task - - -@pytest.mark.skipif(pytest.IS_WINDOWS, reason="not implemented") -class TestBasicUnixTransport: - """Test transport module, unix part.""" - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("setup_server", [True, False]) - def test_properties(self, params, setup_server, transport, commparams): - """Test properties.""" - transport.setup_unix(setup_server, params.host) - commparams.host = params.host - assert transport.comm_params == commparams - assert transport.call_connect_listen - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("setup_server", [True, False]) - def test_properties_windows(self, params, setup_server, transport): - """Test properties.""" - with mock.patch( - "pymodbus.transport.transport.sys.platform", return_value="windows" - ), pytest.raises(RuntimeError): - transport.setup_unix(setup_server, params.host) - - @pytest.mark.xdist_group(name="server_serialize") - async def test_connect(self, params, transport): - """Test connect_unix().""" - transport.setup_unix(False, params.host) - mocker = mock.AsyncMock() - transport.loop.create_unix_connection = mocker - mocker.side_effect = FileNotFoundError("testing") - assert not await transport.transport_connect() - mocker.side_effect = None - - mocker.return_value = (mock.Mock(), mock.Mock()) - assert await transport.transport_connect() - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_listen(self, params, transport): - """Test listen_unix().""" - transport.setup_unix(True, params.host) - mocker = mock.AsyncMock() - transport.loop.create_unix_server = mocker - mocker.side_effect = OSError("testing") - assert await transport.transport_listen() is None - mocker.side_effect = None - - mocker.return_value = mock.Mock() - assert mocker.return_value == await transport.transport_listen() - transport.close() - - -class TestBasicTcpTransport: - """Test transport module, tcp part.""" - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("setup_server", [True, False]) - def test_properties(self, params, setup_server, transport, commparams): - """Test properties.""" - transport.setup_tcp(setup_server, params.host, params.port) - commparams.host = params.host - commparams.port = params.port - assert transport.comm_params == commparams - assert transport.call_connect_listen - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_connect(self, params, transport): - """Test connect_tcp().""" - transport.setup_tcp(False, params.host, params.port) - mocker = mock.AsyncMock() - transport.loop.create_connection = mocker - mocker.side_effect = asyncio.TimeoutError("testing") - assert not await transport.transport_connect() - mocker.side_effect = None - - mocker.return_value = (mock.Mock(), mock.Mock()) - assert await transport.transport_connect() - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_listen(self, params, transport): - """Test listen_tcp().""" - transport.setup_tcp(True, params.host, params.port) - mocker = mock.AsyncMock() - transport.loop.create_server = mocker - mocker.side_effect = OSError("testing") - assert await transport.transport_listen() is None - mocker.side_effect = None - - mocker.return_value = mock.Mock() - assert mocker.return_value == await transport.transport_listen() - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_is_active(self, params, transport): - """Test properties.""" - transport.setup_tcp(False, params.host, params.port) - assert not transport.is_active() - transport.connection_made(mock.AsyncMock()) - assert transport.is_active() - transport.close() + def test_generate_ssl(self, commparams): + """Test ssl generattion""" + with mock.patch("pymodbus.transport.transport.ssl.SSLContext"): + commparams.generate_ssl(True, "cert_file", "key_file") + assert commparams.sslctx -class TestBasicTlsTransport: - """Test transport module, tls part.""" +class TestBasicNullModem: + """Test transport null modem module.""" - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("setup_server", [True, False]) - @pytest.mark.parametrize("sslctx", [None, "test ctx"]) - def test_properties(self, setup_server, sslctx, params, transport, commparams): - """Test properties.""" - with mock.patch("pymodbus.transport.transport.ssl.SSLContext"): - transport.setup_tls( - setup_server, - params.host, - params.port, - sslctx, - "certfile dummy", - None, - None, - params.server_hostname, - ) - commparams.host = params.host - commparams.port = params.port - commparams.server_hostname = params.server_hostname - commparams.ssl = sslctx if sslctx else transport.comm_params.ssl - assert transport.comm_params == commparams - assert transport.call_connect_listen - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_connect(self, params, transport): - """Test connect_tcls().""" - transport.setup_tls( - False, - params.host, - params.port, - "no ssl", - None, - None, - None, - params.server_hostname, - ) - mocker = mock.AsyncMock() - transport.loop.create_connection = mocker - mocker.side_effect = asyncio.TimeoutError("testing") - assert not await transport.transport_connect() - mocker.side_effect = None - - mocker.return_value = (mock.Mock(), mock.Mock()) - assert await transport.transport_connect() - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_listen(self, params, transport): - """Test listen_tls().""" - transport.setup_tls( - True, - params.host, - params.port, - "no ssl", - None, - None, - None, - params.server_hostname, - ) - mocker = mock.AsyncMock() - transport.loop.create_server = mocker - mocker.side_effect = OSError("testing") - assert await transport.transport_listen() is None - mocker.side_effect = None - - mocker.return_value = mock.Mock() - assert mocker.return_value == await transport.transport_listen() - transport.close() - - -class TestBasicUdpTransport: - """Test transport module, udp part.""" - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("setup_server", [True, False]) - def test_properties(self, params, setup_server, transport, commparams): - """Test properties.""" - transport.setup_udp(setup_server, params.host, params.port) - commparams.host = params.host - commparams.port = params.port - assert transport.comm_params == commparams - assert transport.call_connect_listen - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_connect(self, params, transport): - """Test connect_udp().""" - transport.setup_udp(False, params.host, params.port) - mocker = mock.AsyncMock() - transport.loop.create_datagram_endpoint = mocker - mocker.side_effect = asyncio.TimeoutError("testing") - assert not await transport.transport_connect() - mocker.side_effect = None - - mocker.return_value = (mock.Mock(), mock.Mock()) - assert await transport.transport_connect() - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_listen(self, params, transport): - """Test listen_udp().""" - transport.setup_udp(True, params.host, params.port) - mocker = mock.AsyncMock() - transport.loop.create_datagram_endpoint = mocker - mocker.side_effect = OSError("testing") - assert await transport.transport_listen() is None - mocker.side_effect = None - - mocker.return_value = (mock.Mock(), mock.Mock()) - assert await transport.transport_listen() == mocker.return_value[0] - transport.close() - - -class TestBasicSerialTransport: - """Test transport module, serial part.""" - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("setup_server", [True, False]) - def test_properties(self, params, setup_server, transport, commparams): - """Test properties.""" - transport.setup_serial( - setup_server, - params.host, - params.baudrate, - params.bytesize, - params.parity, - params.stopbits, - ) - commparams.host = params.host - commparams.baudrate = params.baudrate - commparams.bytesize = params.bytesize - commparams.parity = params.parity - commparams.stopbits = params.stopbits - assert transport.comm_params == commparams - assert transport.call_connect_listen - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_connect(self, params, transport): - """Test connect_serial().""" - transport.setup_serial( - False, - params.host, - params.baudrate, - params.bytesize, - params.parity, - params.stopbits, - ) - mocker = mock.AsyncMock() - with mock.patch( - "pymodbus.transport.transport.create_serial_connection", new=mocker - ): - mocker.side_effect = asyncio.TimeoutError("testing") - assert not await transport.transport_connect() - mocker.side_effect = None - - mocker.return_value = (mock.Mock(), mock.Mock()) - assert await transport.transport_connect() - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_listen(self, params, transport): - """Test listen_serial().""" - transport.setup_serial( - True, - params.host, - params.baudrate, - params.bytesize, - params.parity, - params.stopbits, - ) - mocker = mock.AsyncMock() - with mock.patch( - "pymodbus.transport.transport.create_serial_connection", new=mocker - ): - mocker.side_effect = SerialException("testing") - assert await transport.transport_listen() is None - mocker.side_effect = None - - mocker.return_value = mock.Mock() - assert await transport.transport_listen() == mocker.return_value - transport.close() + def test_init(self): + """Test null modem init""" + NullModem.server = None + with pytest.raises(RuntimeError): + NullModem(False, mock.Mock()) + NullModem(True, mock.Mock()) + NullModem(False, mock.Mock()) + + def test_external_methods(self): + """Test external methods.""" + modem = NullModem(True, mock.Mock()) + modem.close() + modem.sendto(b"abcd") + modem.write(b"abcd") + + def test_abstract_methods(self): + """Test asyncio abstract methods.""" + modem = NullModem(True, mock.Mock()) + modem.abort() + modem.can_write_eof() + modem.get_write_buffer_size() + modem.get_write_buffer_limits() + modem.set_write_buffer_limits(1024, 1) + modem.write_eof() + modem.get_protocol() + modem.set_protocol(None) + modem.is_closing() diff --git a/test/sub_transport/test_comm.py b/test/sub_transport/test_comm.py index bfbfc5f8dc..76905c4b71 100644 --- a/test/sub_transport/test_comm.py +++ b/test/sub_transport/test_comm.py @@ -1,241 +1,132 @@ """Test transport.""" +import asyncio import time import pytest +from pymodbus.transport.transport import CommType -@pytest.mark.skipif(pytest.IS_WINDOWS, reason="not implemented.") -class TestCommUnixTransport: - """Test for the transport module.""" - - @pytest.mark.parametrize("positive", [True, False]) - async def test_connect(self, transport, domain_socket): - """Test connect_unix().""" - transport.setup_unix(False, domain_socket) - start = time.time() - assert not await transport.transport_connect() - delta = time.time() - start - assert delta < transport.comm_params.timeout_connect * 1.2 - transport.close() - @pytest.mark.parametrize("positive", [True, False]) - async def test_listen(self, transport_server, positive, domain_socket): - """Test listen_unix().""" - transport_server.setup_unix(True, domain_socket) - server = await transport_server.transport_listen() - assert positive == bool(server) - assert positive == bool(transport_server.transport) - if server: - server.close() - transport_server.close() +BASE_PORT = 6100 - @pytest.mark.parametrize("positive", [True]) - async def test_connected(self, transport, transport_server, domain_socket): - """Test listen/connect unix().""" - transport_server.setup_unix(True, domain_socket) - await transport_server.transport_listen() - transport.setup_unix(False, domain_socket) - assert await transport.transport_connect() - transport.close() - transport_server.close() - - -class TestCommTcpTransport: +class TestCommTransport: """Test for the transport module.""" - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("positive", [True, False]) - async def test_connect(self, transport, use_port, domain_host): - """Test connect_tcp().""" - transport.setup_tcp(False, domain_host, use_port) + @pytest.mark.parametrize( + ("use_comm_type", "use_port"), + [ + (CommType.TCP, BASE_PORT + 1), + (CommType.TLS, BASE_PORT + 2), + # (CommType.UDP, BASE_PORT + 3), udp is connectionless. + (CommType.SERIAL, BASE_PORT + 4), + ], + ) + async def test_connect(self, client): + """Test connect().""" start = time.time() - assert not await transport.transport_connect() + assert not await client.transport_connect() delta = time.time() - start - assert delta < transport.comm_params.timeout_connect * 1.2 - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("positive", [True, False]) - async def test_listen(self, transport_server, use_port, positive, domain_host): - """Test listen_tcp().""" - transport_server.setup_tcp(True, domain_host, use_port) - server = await transport_server.transport_listen() - assert positive == bool(server) - assert positive == bool(transport_server.transport) - transport_server.close() - if server: - server.close() - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("positive", [True]) - async def test_connected(self, transport, transport_server, use_port, domain_host): - """Test listen/connect tcp().""" - transport_server.setup_tcp(True, domain_host, use_port) - server = await transport_server.transport_listen() - assert server - transport.setup_tcp(False, domain_host, use_port) - assert await transport.transport_connect() - transport.close() - transport_server.close() - server.close() - - -class TestCommTlsTransport: - """Test for the transport module.""" - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("positive", [True, False]) - async def test_connect(self, transport, params, domain_host): - """Test connect_tls().""" - transport.setup_tls( - False, - domain_host, - params.port, - None, - params.cwd + "crt", - params.cwd + "key", - None, - "localhost", - ) + assert delta < client.comm_params.timeout_connect * 1.2 + client.close() + + @pytest.mark.parametrize( + ("use_comm_type", "use_port"), + [ + (CommType.TCP, BASE_PORT + 5), + (CommType.TLS, BASE_PORT + 6), + # (CommType.UDP, BASE_PORT + 7), udp is connectionless. + (CommType.SERIAL, BASE_PORT + 8), + ], + ) + async def test_connect_not_ok(self, client): + """Test connect().""" + client.comm_params.host = "/illegal_host" start = time.time() - assert not await transport.transport_connect() + assert not await client.transport_connect() delta = time.time() - start - assert delta < transport.comm_params.timeout_connect * 1.2 - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("positive", [True, False]) - async def test_listen(self, transport_server, params, positive, domain_host): - """Test listen_tls().""" - transport_server.setup_tls( - True, - domain_host, - params.port, - None, - params.cwd + "crt", - params.cwd + "key", - None, - "localhost", - ) - server = await transport_server.transport_listen() - assert positive == bool(server) - assert positive == bool(transport_server.transport) - transport_server.close() - if server: - server.close() - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("positive", [True]) - async def test_connected(self, transport, transport_server, params, domain_host): - """Test listen/connect tls().""" - transport_server.setup_tls( - True, - domain_host, - params.port, - None, - params.cwd + "crt", - params.cwd + "key", - None, - "localhost", - ) - server = await transport_server.transport_listen() - assert server - - transport.setup_tcp(False, domain_host, params.port) - assert await transport.transport_connect() - transport.close() - transport_server.close() + assert delta < client.comm_params.timeout_connect * 1.2 + client.close() + + @pytest.mark.parametrize( + ("use_comm_type", "use_port"), + [ + (CommType.TCP, BASE_PORT + 9), + (CommType.TLS, BASE_PORT + 10), + (CommType.UDP, BASE_PORT + 11), + # (CommType.SERIAL, BASE_PORT + 12), there are no standard tty port + ], + ) + async def test_listen(self, server): + """Test listen().""" + assert await server.transport_listen() + assert server.transport server.close() - -class TestCommUdpTransport: - """Test for the transport module.""" - - async def test_connect(self): - """Test connect_udp().""" - # always true, since udp is connectionless. - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("positive", [True, False]) - async def test_listen(self, transport_server, use_port, positive, domain_host): - """Test listen_udp().""" - transport_server.setup_udp(True, domain_host, use_port) - server = await transport_server.transport_listen() - assert positive == bool(server) - assert positive == bool(transport_server.transport) - transport_server.close() - if server: - server.close() - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("positive", [True]) - async def test_connected(self, transport, transport_server, use_port, domain_host): - """Test listen/connect udp().""" - transport_server.setup_udp(True, domain_host, use_port) - server = await transport_server.transport_listen() - assert server - transport.setup_udp(False, domain_host, use_port) - assert await transport.transport_connect() - transport.close() - transport_server.close() + @pytest.mark.parametrize( + ("use_comm_type", "use_port"), + [ + (CommType.TCP, BASE_PORT + 13), + (CommType.TLS, BASE_PORT + 14), + (CommType.UDP, BASE_PORT + 15), + (CommType.SERIAL, BASE_PORT + 16), + ], + ) + async def test_listen_not_ok(self, server): + """Test listen().""" + server.comm_params.host = "/illegal_host" + assert not await server.transport_listen() + assert not server.transport server.close() + @pytest.mark.parametrize( + ("use_comm_type", "use_port"), + [ + (CommType.TCP, BASE_PORT + 13), + (CommType.TLS, BASE_PORT + 14), + (CommType.UDP, BASE_PORT + 15), + (CommType.SERIAL, BASE_PORT + 16), + ], + ) + async def test_connected(self, client, server, use_comm_type): + """Test connection and data exchange.""" + assert await server.transport_listen() + assert await client.transport_connect() + await asyncio.sleep(0.5) + assert len(server.active_connections) == 1 + server_connected = list(server.active_connections.values())[0] + test_data = b"abcd" + client.send(test_data) + await asyncio.sleep(0.5) + assert server_connected.recv_buffer == test_data + assert not client.recv_buffer + server_connected.recv_buffer = b"" + if use_comm_type == CommType.UDP: + sock = client.transport.get_extra_info("socket") + addr = sock.getsockname() + server_connected.send(test_data, addr=addr) + else: + server_connected.send(test_data) + await asyncio.sleep(2) + assert client.recv_buffer == test_data + assert not server_connected.recv_buffer + client.close() + server.close() -class TestCommSerialTransport: - """Test for the transport module.""" - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("positive", [True, False]) - async def test_connect(self, transport, use_port, positive): - """Test connect_serial().""" - domain_port = f"unix:/localhost:{use_port}" if positive else "/illegal_port" - transport.setup_serial( - False, - domain_port, - 9600, - 8, - "E", - 2, - ) - start = time.time() - assert not await transport.transport_connect() - delta = time.time() - start - assert delta < transport.comm_params.timeout_connect * 1.2 - transport.close() - async def test_listen(self, transport_server): - """Test listen_serial().""" - transport_server.setup_serial( - True, - "/illegal_port", - 9600, - 8, - "E", - 2, - ) - server = await transport_server.transport_listen() - assert not server - assert not transport_server.transport - transport_server.close() +class TestCommNullModem: # pylint: disable=too-few-public-methods + """Test null modem module.""" - # there are no positive test, since there are no standard tty port + def test_class_variables(self, nullmodem_server, nullmodem): + """Test connection_made().""" + assert nullmodem.client + assert nullmodem.server + assert nullmodem_server.client + assert nullmodem_server.server + nullmodem.__class__.client = self + nullmodem.is_server = False + nullmodem_server.__class__.server = self + nullmodem_server.is_server = True - @pytest.mark.xdist_group(name="server_serialize") - async def test_connected(self, transport, transport_server, use_port): - """Test listen/connect serial().""" - transport_server.setup_tcp(True, "localhost", use_port) - server = await transport_server.transport_listen() - assert server - transport.setup_serial( - False, - f"socket://localhost:{use_port}", - 9600, - 8, - "E", - 2, - ) - assert await transport.transport_connect() - transport.close() - transport_server.close() - server.close() + assert nullmodem.client == nullmodem_server.client + assert nullmodem.server == nullmodem_server.server diff --git a/test/sub_transport/test_data.py b/test/sub_transport/test_data.py deleted file mode 100644 index 6d9535e4d9..0000000000 --- a/test/sub_transport/test_data.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Test transport.""" -import asyncio - -import pytest - - -class TestDataTransport: # pylint: disable=too-few-public-methods - """Test for the transport module.""" - - @pytest.mark.xdist_group(name="server_serialize") - async def test_client_send(self, transport, transport_server, use_port): - """Test send().""" - transport_server.setup_tcp(True, "localhost", use_port) - server = await transport_server.transport_listen() - assert transport_server.transport - - transport.setup_tcp(False, "localhost", use_port) - assert await transport.transport_connect() - await transport.send(b"ABC") - await asyncio.sleep(2) - assert transport_server.recv_buffer == b"ABC" - await transport_server.send(b"DEF") - await asyncio.sleep(2) - assert transport.recv_buffer == b"DEF" - transport.close() - transport_server.close() - server.close() diff --git a/test/sub_transport/test_nullmodem.py b/test/sub_transport/test_nullmodem.py deleted file mode 100644 index 3e61765e7d..0000000000 --- a/test/sub_transport/test_nullmodem.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Test transport.""" - -from pymodbus.transport.nullmodem import DummyTransport - - -class TestNullModemTransport: - """Test null modem module.""" - - async def test_str_magic(self, nullmodem, params): - """Test magic.""" - str(nullmodem) - assert str(nullmodem) == f"NullModem({params.comm_name})" - - def test_DummyTransport(self): - """Test DummyTransport class.""" - socket = DummyTransport() - socket.close() - socket.get_protocol() - socket.is_closing() - socket.set_protocol(None) - socket.abort() - - def test_class_variables(self, nullmodem, nullmodem_server): - """Test connection_made().""" - assert not nullmodem.nullmodem_client - assert not nullmodem.nullmodem_server - assert not nullmodem_server.nullmodem_client - assert not nullmodem_server.nullmodem_server - nullmodem.__class__.nullmodem_client = self - nullmodem.is_server = False - nullmodem_server.__class__.nullmodem_server = self - nullmodem_server.is_server = True - - assert nullmodem.nullmodem_client == nullmodem_server.nullmodem_client - assert nullmodem.nullmodem_server == nullmodem_server.nullmodem_server - - async def test_transport_connect(self, nullmodem): - """Test connection_made().""" - nullmodem.loop = None - assert not await nullmodem.transport_connect() - assert not nullmodem.nullmodem_server - assert not nullmodem.nullmodem_client - assert nullmodem.loop - nullmodem.cb_connection_made.assert_not_called() - nullmodem.cb_connection_lost.assert_not_called() - nullmodem.cb_handle_data.assert_not_called() - - async def test_transport_listen(self, nullmodem_server): - """Test connection_made().""" - nullmodem_server.loop = None - assert await nullmodem_server.transport_listen() - assert nullmodem_server.is_server - assert nullmodem_server.nullmodem_server - assert not nullmodem_server.nullmodem_client - assert nullmodem_server.loop - nullmodem_server.cb_connection_made.assert_not_called() - nullmodem_server.cb_connection_lost.assert_not_called() - nullmodem_server.cb_handle_data.assert_not_called() - - async def test_connected(self, nullmodem, nullmodem_server): - """Test connection is correct.""" - assert await nullmodem_server.transport_listen() - assert await nullmodem.transport_connect() - assert nullmodem.nullmodem_client - assert nullmodem.nullmodem_server - assert nullmodem.loop - assert not nullmodem.is_server - assert nullmodem_server.is_server - nullmodem.cb_connection_made.assert_called_once() - nullmodem.cb_connection_lost.assert_not_called() - nullmodem.cb_handle_data.assert_not_called() - nullmodem_server.cb_connection_made.assert_called_once() - nullmodem_server.cb_connection_lost.assert_not_called() - nullmodem_server.cb_handle_data.assert_not_called() - - async def test_client_close(self, nullmodem, nullmodem_server): - """Test close().""" - assert await nullmodem_server.transport_listen() - assert await nullmodem.transport_connect() - nullmodem.close() - assert not nullmodem.nullmodem_client - assert not nullmodem.nullmodem_server - nullmodem.cb_connection_made.assert_called_once() - nullmodem.cb_connection_lost.assert_called_once() - nullmodem.cb_handle_data.assert_not_called() - nullmodem_server.cb_connection_made.assert_called_once() - nullmodem_server.cb_connection_lost.assert_called_once() - nullmodem_server.cb_handle_data.assert_not_called() - - async def test_server_close(self, nullmodem, nullmodem_server): - """Test close().""" - assert await nullmodem_server.transport_listen() - assert await nullmodem.transport_connect() - nullmodem_server.close() - assert not nullmodem.nullmodem_client - assert not nullmodem.nullmodem_server - nullmodem.cb_connection_made.assert_called_once() - nullmodem.cb_connection_lost.assert_called_once() - nullmodem.cb_handle_data.assert_not_called() - nullmodem_server.cb_connection_made.assert_called_once() - nullmodem_server.cb_connection_lost.assert_called_once() - nullmodem_server.cb_handle_data.assert_not_called() - - async def test_data(self, nullmodem, nullmodem_server): - """Test data exchange.""" - data = b"abcd" - assert await nullmodem_server.transport_listen() - assert await nullmodem.transport_connect() - assert await nullmodem.send(data) - assert nullmodem_server.recv_buffer == data - assert not nullmodem.recv_buffer - nullmodem.cb_handle_data.assert_not_called() - nullmodem_server.cb_handle_data.assert_called_once() - assert await nullmodem_server.send(data) - assert nullmodem_server.recv_buffer == data - assert nullmodem.recv_buffer == data - nullmodem.cb_handle_data.assert_called_once() - nullmodem_server.cb_handle_data.assert_called_once() diff --git a/test/sub_transport/test_reconnect.py b/test/sub_transport/test_reconnect.py index defb815c93..dd3bf24806 100644 --- a/test/sub_transport/test_reconnect.py +++ b/test/sub_transport/test_reconnect.py @@ -2,68 +2,60 @@ import asyncio from unittest import mock -import pytest - class TestReconnectTransport: """Test transport module, base part.""" - @pytest.mark.xdist_group(name="server_serialize") - async def test_no_reconnect_call(self, transport, use_port, commparams): + async def test_no_reconnect_call(self, client): """Test connection_lost().""" - transport.setup_tcp(False, "localhost", use_port) - mocker = mock.AsyncMock(return_value=(None, None)) - transport.loop.create_connection = mocker - transport.connection_made(mock.Mock()) - assert not mocker.call_count - assert transport.reconnect_delay_current == commparams.reconnect_delay - transport.connection_lost(RuntimeError("Connection lost")) - assert not mocker.call_count - assert transport.reconnect_delay_current == commparams.reconnect_delay - transport.close() + client.loop.create_connection = mock.AsyncMock(return_value=(None, None)) + await client.transport_connect() + client.connection_lost(RuntimeError("Connection lost")) + assert not client.reconnect_task + assert client.loop.create_connection.call_count + assert not client.reconnect_delay_current + client.close() - @pytest.mark.xdist_group(name="server_serialize") - async def test_reconnect_call(self, transport, use_port, commparams): + async def test_reconnect_call(self, client, commparams): """Test connection_lost().""" - transport.setup_tcp(False, "localhost", use_port) - mocker = mock.AsyncMock(return_value=(None, None)) - transport.loop.create_connection = mocker - transport.connection_made(mock.Mock()) - transport.connection_lost(RuntimeError("Connection lost")) - await asyncio.sleep(transport.reconnect_delay_current * 1.8) - assert mocker.call_count == 1 - assert transport.reconnect_delay_current == commparams.reconnect_delay * 2 - transport.close() + client.loop.create_connection = mock.AsyncMock(return_value=(None, None)) + await client.transport_connect() + client.connection_made(mock.Mock()) + client.connection_lost(RuntimeError("Connection lost")) + assert client.reconnect_task + await asyncio.sleep(client.reconnect_delay_current * 1.8) + assert client.reconnect_task + assert client.loop.create_connection.call_count == 2 + assert client.reconnect_delay_current == commparams.reconnect_delay * 2 + client.close() - @pytest.mark.xdist_group(name="server_serialize") - async def test_multi_reconnect_call(self, transport, use_port, commparams): + async def test_multi_reconnect_call(self, client, commparams): """Test connection_lost().""" - transport.setup_tcp(False, "localhost", use_port) - mocker = mock.AsyncMock(return_value=(None, None)) - transport.loop.create_connection = mocker - transport.connection_made(mock.Mock()) - transport.connection_lost(RuntimeError("Connection lost")) - await asyncio.sleep(transport.reconnect_delay_current * 1.8) - assert mocker.call_count == 1 - assert transport.reconnect_delay_current == commparams.reconnect_delay * 2 - await asyncio.sleep(transport.reconnect_delay_current * 1.8) - assert mocker.call_count == 2 - assert transport.reconnect_delay_current == commparams.reconnect_delay_max - await asyncio.sleep(transport.reconnect_delay_current * 1.8) - assert mocker.call_count >= 3 - assert transport.reconnect_delay_current == commparams.reconnect_delay_max - transport.close() + client.loop.create_connection = mock.AsyncMock(return_value=(None, None)) + await client.transport_connect() + client.connection_made(mock.Mock()) + client.connection_lost(RuntimeError("Connection lost")) + await asyncio.sleep(client.reconnect_delay_current * 1.8) + assert client.loop.create_connection.call_count == 2 + assert client.reconnect_delay_current == commparams.reconnect_delay * 2 + await asyncio.sleep(client.reconnect_delay_current * 1.8) + assert client.loop.create_connection.call_count == 3 + assert client.reconnect_delay_current == commparams.reconnect_delay_max + await asyncio.sleep(client.reconnect_delay_current * 1.8) + assert client.loop.create_connection.call_count >= 4 + assert client.reconnect_delay_current == commparams.reconnect_delay_max + client.close() - @pytest.mark.xdist_group(name="server_serialize") - async def test_reconnect_call_ok(self, transport, use_port, commparams): + async def test_reconnect_call_ok(self, client, commparams): """Test connection_lost().""" - transport.setup_tcp(False, "localhost", use_port) - mocker = mock.AsyncMock(return_value=(mock.Mock(), mock.Mock())) - transport.loop.create_connection = mocker - transport.connection_made(mock.Mock()) - transport.connection_lost(RuntimeError("Connection lost")) - await asyncio.sleep(transport.reconnect_delay_current * 1.8) - assert mocker.call_count == 1 - assert transport.reconnect_delay_current == commparams.reconnect_delay - assert not transport.reconnect_task - transport.close() + client.loop.create_connection = mock.AsyncMock( + return_value=(mock.Mock(), mock.Mock()) + ) + await client.transport_connect() + client.connection_made(mock.Mock()) + client.connection_lost(RuntimeError("Connection lost")) + await asyncio.sleep(client.reconnect_delay_current * 1.8) + assert client.loop.create_connection.call_count == 2 + assert client.reconnect_delay_current == commparams.reconnect_delay + assert not client.reconnect_task + client.close()