From 71ceb06305ef4ad8bbad582d8641dc268e0321cb Mon Sep 17 00:00:00 2001 From: jan iversen Date: Fri, 23 Jun 2023 20:07:41 +0200 Subject: [PATCH] Integrate transport in server. (#1617) --- API_changes.rst | 6 +- doc/source/library/simulator/config.rst | 5 - examples/contrib/serial_forwarder.py | 3 +- examples/datastore_simulator.py | 1 - examples/server_async.py | 3 - examples/server_sync.py | 3 - pymodbus/client/base.py | 66 +- pymodbus/client/serial.py | 24 +- pymodbus/client/tcp.py | 46 +- pymodbus/client/tls.py | 65 +- pymodbus/client/udp.py | 16 +- pymodbus/server/__init__.py | 4 - pymodbus/server/async_io.py | 411 ++++-------- pymodbus/server/reactive/default_config.json | 4 - pymodbus/server/reactive/default_config.py | 4 - pymodbus/server/simulator/setup.json | 3 - pymodbus/transport/nullmodem.py | 115 ---- pymodbus/transport/transport.py | 518 ++++++++------- test/sub_examples/conftest.py | 9 +- test/sub_examples/test_client_server_async.py | 35 +- test/sub_examples/test_client_server_sync.py | 35 +- test/sub_examples/test_examples.py | 48 +- test/sub_transport/conftest.py | 201 +++--- test/sub_transport/test_basic.py | 627 ++++++------------ test/sub_transport/test_comm.py | 344 ++++------ test/sub_transport/test_data.py | 27 - test/sub_transport/test_nullmodem.py | 118 ---- test/sub_transport/test_reconnect.py | 100 ++- test/test_client.py | 24 +- test/test_client_sync.py | 24 - test/test_server_asyncio.py | 6 +- test/test_unix_socket.py | 61 -- 32 files changed, 1034 insertions(+), 1922 deletions(-) delete mode 100644 pymodbus/transport/nullmodem.py delete mode 100644 test/sub_transport/test_data.py delete mode 100644 test/sub_transport/test_nullmodem.py delete mode 100755 test/test_unix_socket.py diff --git a/API_changes.rst b/API_changes.rst index 79d3fee46..4895f3461 100644 --- a/API_changes.rst +++ b/API_changes.rst @@ -6,7 +6,11 @@ PyModbus - API changes. Version 3.4.0 ------------- - ModbusClient .connect() returns True/False (connected or not) -- ModbueServer handler= no longer accepted +- ModbueServer handler=, allow_reuse_addr=, backlog= are no longer accepted +- ModbusTcpClient / AsyncModbusTcpClient no longer support unix path +- StartAsyncUnixServer / ModbusUnixServer removed (never worked on Windows) +- ModbusTlsServer reqclicert= is not longer accepted + ------------- Version 3.3.1 diff --git a/doc/source/library/simulator/config.rst b/doc/source/library/simulator/config.rst index ceb9e3229..9bfcaef22 100644 --- a/doc/source/library/simulator/config.rst +++ b/doc/source/library/simulator/config.rst @@ -44,7 +44,6 @@ The entries for a tcp server with minimal parameters look like: "comm": "tcp", "host": "0.0.0.0", "port": 5020, - "allow_reuse_address": true, "framer": "socket", } } @@ -60,7 +59,6 @@ The entry “comm” allows the following values: - “serial”, to use :class:`pymodbus.server.ModbusSerialServer`, - “tcp”, to use :class:`pymodbus.server.ModbusTcpServer`, - “tls”, to use :class:`pymodbus.server.ModbusTlsServer`, -- “unix”, to use :class:`pymodbus.server.ModbusUnixServer`, - “udp”; to use :class:`pymodbus.server.ModbusUdpServer`. The entry “framer” allows the following values: @@ -87,7 +85,6 @@ Server configuration examples "comm": "tcp", "host": "0.0.0.0", "port": 5020, - "allow_reuse_address": true, "ignore_missing_slaves": false, "framer": "socket", "identity": { @@ -125,8 +122,6 @@ Server configuration examples "port": 5020, "certfile": "certificates/pymodbus.crt", "keyfile": "certificates/pymodbus.key", - "allow_reuse_address": true, - "backlog": 20, "ignore_missing_slaves": false, "framer": "tls", "identity": { diff --git a/examples/contrib/serial_forwarder.py b/examples/contrib/serial_forwarder.py index b29fd824d..7cb42acad 100644 --- a/examples/contrib/serial_forwarder.py +++ b/examples/contrib/serial_forwarder.py @@ -41,7 +41,8 @@ async def run(self): store[i] = RemoteSlaveContext(client, slave=i) context = ModbusServerContext(slaves=store, single=False) self.server = ModbusTcpServer( - context, address=(server_ip, server_port), allow_reuse_address=True + context, + address=(server_ip, server_port), ) message = f"serving on {server_ip} port {server_port}" _logger.info(message) diff --git a/examples/datastore_simulator.py b/examples/datastore_simulator.py index 8c64413e3..b95039b7a 100755 --- a/examples/datastore_simulator.py +++ b/examples/datastore_simulator.py @@ -169,7 +169,6 @@ async def run_server_simulator(args): context=args.context, address=("", args.port), framer=args.framer, - allow_reuse_address=True, ) diff --git a/examples/server_async.py b/examples/server_async.py index 1458ecd75..4939c6b8d 100755 --- a/examples/server_async.py +++ b/examples/server_async.py @@ -154,7 +154,6 @@ async def run_async_server(args): address=address, # listen address # custom_functions=[], # allow custom handling framer=args.framer, # The framer strategy to use - allow_reuse_address=True, # allow the reuse of an address # ignore_missing_slaves=True, # ignore request to a missing slave # broadcast_enable=False, # treat slave_id 0 as broadcast address, # timeout=1, # waiting time for request to complete @@ -202,7 +201,6 @@ async def run_async_server(args): # custom_functions=[], # allow custom handling address=address, # listen address framer=args.framer, # The framer strategy to use - allow_reuse_address=True, # allow the reuse of an address certfile=helper.get_certificate( "crt" ), # The cert file path for TLS (used if sslctx is None) @@ -211,7 +209,6 @@ async def run_async_server(args): "key" ), # The key file path for TLS (used if sslctx is None) # password="none", # The password for for decrypting the private key file - # reqclicert=False, # Force the sever request client"s certificate # ignore_missing_slaves=True, # ignore request to a missing slave # broadcast_enable=False, # treat slave_id 0 as broadcast address, # timeout=1, # waiting time for request to complete diff --git a/examples/server_sync.py b/examples/server_sync.py index 24d461ef7..e193dce1d 100755 --- a/examples/server_sync.py +++ b/examples/server_sync.py @@ -66,7 +66,6 @@ def run_sync_server(args): address=address, # listen address # custom_functions=[], # allow custom handling framer=args.framer, # The framer strategy to use - allow_reuse_address=True, # allow the reuse of an address # ignore_missing_slaves=True, # ignore request to a missing slave # broadcast_enable=False, # treat slave_id 0 as broadcast address, # timeout=1, # waiting time for request to complete @@ -114,7 +113,6 @@ def run_sync_server(args): # custom_functions=[], # allow custom handling address=address, # listen address framer=args.framer, # The framer strategy to use - allow_reuse_address=True, # allow the reuse of an address certfile=helper.get_certificate( "crt" ), # The cert file path for TLS (used if sslctx is None) @@ -123,7 +121,6 @@ def run_sync_server(args): "key" ), # The key file path for TLS (used if sslctx is None) # password=None, # The password for for decrypting the private key file - # reqclicert=False, # Force the sever request client"s certificate # ignore_missing_slaves=True, # ignore request to a missing slave # broadcast_enable=False, # treat slave_id 0 as broadcast address, # timeout=1, # waiting time for request to complete diff --git a/pymodbus/client/base.py b/pymodbus/client/base.py index 6392fc71b..dfe2145a3 100644 --- a/pymodbus/client/base.py +++ b/pymodbus/client/base.py @@ -13,11 +13,11 @@ from pymodbus.logging import Log from pymodbus.pdu import ModbusRequest, ModbusResponse from pymodbus.transaction import DictTransactionManager -from pymodbus.transport.transport import Transport +from pymodbus.transport.transport import CommParams, Transport from pymodbus.utilities import ModbusTransactionState -class ModbusBaseClient(ModbusClientMixin): +class ModbusBaseClient(ModbusClientMixin, Transport): """**ModbusBaseClient** **Parameters common to all clients**: @@ -49,19 +49,17 @@ class ModbusBaseClient(ModbusClientMixin): """ @dataclass - class _params: # pylint: disable=too-many-instance-attributes + class _params: """Parameter class.""" host: str = None port: str | int = None - framer: type[ModbusFramer] = None timeout: float = None retries: int = None retry_on_empty: bool = None close_comm_on_error: bool = None strict: bool = None broadcast_enable: bool = None - kwargs: dict = None reconnect_delay: int = None baudrate: int = None @@ -72,10 +70,6 @@ class _params: # pylint: disable=too-many-instance-attributes source_address: tuple[str, int] = None - sslctx: str = None - certfile: str = None - keyfile: str = None - password: str = None server_hostname: str = None def __init__( # pylint: disable=too-many-arguments @@ -93,19 +87,27 @@ def __init__( # pylint: disable=too-many-arguments **kwargs: Any, ) -> None: """Initialize a client instance.""" - self.new_transport = Transport( - "comm", - reconnect_delay * 1000, - reconnect_delay_max * 1000, - timeout * 1000, - lambda: None, - self.cb_base_connection_lost, - self.cb_base_handle_data, + ModbusClientMixin.__init__(self) + Transport.__init__( + self, + CommParams( + comm_type=kwargs.get("CommType"), + comm_name="comm", + reconnect_delay=reconnect_delay, + reconnect_delay_max=reconnect_delay_max, + timeout_connect=timeout, + host=kwargs.get("host", None), + port=kwargs.get("port", None), + sslctx=kwargs.get("sslctx", None), + baudrate=kwargs.get("baudrate", None), + bytesize=kwargs.get("bytesize", None), + parity=kwargs.get("parity", None), + stopbits=kwargs.get("stopbits", None), + ), + False, ) - self.framer = framer self.params = self._params() - self.params.framer = framer self.params.timeout = float(timeout) self.params.retries = int(retries) self.params.retry_on_empty = bool(retry_on_empty) @@ -115,7 +117,6 @@ def __init__( # pylint: disable=too-many-arguments self.params.reconnect_delay = int(reconnect_delay) self.reconnect_delay_max = int(reconnect_delay_max) self.on_reconnect_callback = on_reconnect_callback - self.params.kwargs = kwargs self.retry_on_empty: int = 0 # -> retry read on nothing @@ -123,7 +124,7 @@ def __init__( # pylint: disable=too-many-arguments # -> list of acceptable slaves (0 for accept all) # Common variables. - self.framer = self.params.framer(ClientDecoder(), self) + self.framer = framer(ClientDecoder(), self) self.transaction = DictTransactionManager( self, retries=retries, retry_on_empty=retry_on_empty, **kwargs ) @@ -135,9 +136,6 @@ def __init__( # pylint: disable=too-many-arguments self.last_frame_end: float = 0 self.silent_interval: float = 0 - # Initialize mixin - ModbusClientMixin.__init__(self) - # ----------------------------------------------------------------------- # # Client external interface # ----------------------------------------------------------------------- # @@ -152,9 +150,9 @@ def register(self, custom_response_class: ModbusResponse) -> None: """ self.framer.decoder.register(custom_response_class) - def close(self, reconnect: bool = False) -> None: + def close(self, reconnect=False) -> None: """Close connection.""" - self.new_transport.close(reconnect=reconnect) + self.transport_close(reconnect=reconnect) def idle_time(self) -> float: """Time before initiating next transaction (call **sync**). @@ -177,7 +175,7 @@ def execute(self, request: ModbusRequest = None) -> ModbusResponse: if not self.connect(): raise ConnectionException(f"Failed to connect[{str(self)}]") return self.transaction.execute(request) - if not self.new_transport.transport: + if not self.transport: raise ConnectionException(f"Not connected[{str(self)}]") return self.async_execute(request) @@ -189,11 +187,7 @@ async def async_execute(self, request=None): request.transaction_id = self.transaction.getNextTID() packet = self.framer.buildPacket(request) Log.debug("send: {}", packet, ":hex") - # if self.use_udp: - # self.new_transport.transport.sendto(packet) - # else: - # self.new_transport.transport.write(packet) - await self.new_transport.send(packet) + self.transport_send(packet) req = self._build_response(request.transaction_id) if self.params.broadcast_enable and not request.slave_id: resp = b"Broadcast write sent - no response expected" @@ -205,16 +199,16 @@ async def async_execute(self, request=None): raise return resp - def cb_base_handle_data(self, data: bytes) -> int: + def callback_data(self, data: bytes, addr: tuple = None) -> int: """Handle received data returns number of bytes consumed """ - Log.debug("recv: {}", data, ":hex") + Log.debug("recv: {} addr={}", data, ":hex", addr) self.framer.processIncomingPacket(data, self._handle_response, slave=0) return len(data) - def cb_base_connection_lost(self, _reason: Exception) -> None: + def callback_disconnected(self, _reason: Exception) -> None: """Handle lost connection""" for tid in list(self.transaction): self.raise_future( @@ -243,7 +237,7 @@ def _handle_response(self, reply, **_kwargs): def _build_response(self, tid): """Return a deferred response for the current request.""" my_future = asyncio.Future() - if not self.new_transport.transport: + if not self.transport: self.raise_future(my_future, ConnectionException("Client is not connected")) else: self.transaction.addTransaction(my_future, tid) diff --git a/pymodbus/client/serial.py b/pymodbus/client/serial.py index ca8fbd32b..28d9651e2 100644 --- a/pymodbus/client/serial.py +++ b/pymodbus/client/serial.py @@ -10,6 +10,7 @@ from pymodbus.framer import ModbusFramer from pymodbus.framer.rtu_framer import ModbusRtuFramer from pymodbus.logging import Log +from pymodbus.transport.transport import CommType from pymodbus.utilities import ModbusTransactionState @@ -56,31 +57,38 @@ def __init__( ) -> None: """Initialize Asyncio Modbus Serial Client.""" asyncio.Protocol.__init__(self) - ModbusBaseClient.__init__(self, framer=framer, **kwargs) + ModbusBaseClient.__init__( + self, + framer=framer, + CommType=CommType.SERIAL, + host=port, + baudrate=baudrate, + bytesize=bytesize, + parity=parity, + stopbits=stopbits, + **kwargs, + ) self.params.port = port self.params.baudrate = baudrate self.params.bytesize = bytesize self.params.parity = parity self.params.stopbits = stopbits self.params.handle_local_echo = handle_local_echo - self.new_transport.setup_serial( - False, port, baudrate, bytesize, parity, stopbits - ) @property def connected(self): """Connect internal.""" - return self.new_transport.is_active() + return self.is_active() async def connect(self) -> bool: """Connect Async client.""" # if reconnect_delay_current was set to 0 by close(), we need to set it back again # so this instance will work - self.new_transport.reset_delay() + self.reset_delay() # force reconnect if required: - Log.debug("Connecting to {}.", self.new_transport.comm_params.host) - return await self.new_transport.transport_connect() + Log.debug("Connecting to {}.", self.comm_params.host) + return await self.transport_connect() class ModbusSerialClient(ModbusBaseClient): diff --git a/pymodbus/client/tcp.py b/pymodbus/client/tcp.py index c3c3a9d98..db8c2823d 100644 --- a/pymodbus/client/tcp.py +++ b/pymodbus/client/tcp.py @@ -10,6 +10,7 @@ from pymodbus.framer import ModbusFramer from pymodbus.framer.socket_framer import ModbusSocketFramer from pymodbus.logging import Log +from pymodbus.transport.transport import CommType from pymodbus.utilities import ModbusTransactionState @@ -22,8 +23,6 @@ class AsyncModbusTcpClient(ModbusBaseClient, asyncio.Protocol): :param source_address: (optional) source address of client :param kwargs: (optional) Experimental parameters - using unix domain socket can be achieved by setting host="unix:" - Example:: from pymodbus.client import AsyncModbusTcpClient @@ -46,36 +45,38 @@ def __init__( ) -> None: """Initialize Asyncio Modbus TCP Client.""" asyncio.Protocol.__init__(self) - ModbusBaseClient.__init__(self, framer=framer, **kwargs) + if "CommType" not in kwargs: + kwargs["CommType"] = CommType.TCP + ModbusBaseClient.__init__( + self, + framer=framer, + host=host, + port=port, + **kwargs, + ) self.params.host = host self.params.port = port self.params.source_address = source_address - if "internal_no_setup" in kwargs: - return - if host.startswith("unix:"): - self.new_transport.setup_unix(False, host[5:]) - else: - self.new_transport.setup_tcp(False, host, port) async def connect(self) -> bool: """Initiate connection to start client.""" # if reconnect_delay_current was set to 0 by close(), we need to set it back again # so this instance will work - self.new_transport.reset_delay() + self.reset_delay() # force reconnect if required: Log.debug( "Connecting to {}:{}.", - self.new_transport.comm_params.host, - self.new_transport.comm_params.port, + self.comm_params.host, + self.comm_params.port, ) - return await self.new_transport.transport_connect() + return await self.transport_connect() @property def connected(self): """Return true if connected.""" - return self.new_transport.is_active() + return self.is_active() class ModbusTcpClient(ModbusBaseClient): @@ -87,8 +88,6 @@ class ModbusTcpClient(ModbusBaseClient): :param source_address: (optional) source address of client :param kwargs: (optional) Experimental parameters - using unix domain socket can be achieved by setting host="unix:" - Example:: from pymodbus.client import ModbusTcpClient @@ -129,16 +128,11 @@ def connect(self): # pylint: disable=invalid-overridden-method if self.socket: return True try: - if self.params.host.startswith("unix:"): - self.socket = socket.socket(socket.AF_UNIX) - self.socket.settimeout(self.params.timeout) - self.socket.connect(self.params.host[5:]) - else: - self.socket = socket.create_connection( - (self.params.host, self.params.port), - timeout=self.params.timeout, - source_address=self.params.source_address, - ) + self.socket = socket.create_connection( + (self.params.host, self.params.port), + timeout=self.params.timeout, + source_address=self.params.source_address, + ) Log.debug( "Connection to Modbus server established. Socket {}", self.socket.getsockname(), diff --git a/pymodbus/client/tls.py b/pymodbus/client/tls.py index 11a004989..07fde2539 100644 --- a/pymodbus/client/tls.py +++ b/pymodbus/client/tls.py @@ -7,35 +7,7 @@ from pymodbus.framer import ModbusFramer from pymodbus.framer.tls_framer import ModbusTlsFramer from pymodbus.logging import Log - - -def sslctx_provider( - sslctx=None, certfile=None, keyfile=None, password=None -): # pylint: disable=missing-type-doc - """Provide the SSLContext for ModbusTlsClient. - - If the user defined SSLContext is not passed in, sslctx_provider will - produce a default one. - - :param sslctx: The user defined SSLContext to use for TLS (default None and - auto create) - :param certfile: The optional client's cert file path for TLS server request - :param keyfile: The optional client's key file path for TLS server request - :param password: The password for decrypting client's private key file - """ - if sslctx: - return sslctx - - sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - sslctx.check_hostname = False - sslctx.verify_mode = ssl.CERT_NONE - sslctx.options |= ssl.OP_NO_TLSv1_1 - sslctx.options |= ssl.OP_NO_TLSv1 - sslctx.options |= ssl.OP_NO_SSLv3 - sslctx.options |= ssl.OP_NO_SSLv2 - if certfile and keyfile: - sslctx.load_cert_chain(certfile=certfile, keyfile=keyfile, password=password) - return sslctx +from pymodbus.transport.transport import CommParams, CommType class AsyncModbusTlsClient(AsyncModbusTcpClient): @@ -81,31 +53,32 @@ def __init__( ): """Initialize Asyncio Modbus TLS Client.""" AsyncModbusTcpClient.__init__( - self, host, port=port, framer=framer, internal_no_setup=True, **kwargs + self, + host, + port=port, + framer=framer, + CommType=CommType.TLS, + sslctx=CommParams.generate_ssl( + False, certfile, keyfile, password, sslctx=sslctx + ), + **kwargs, ) - self.sslctx = sslctx_provider(sslctx, certfile, keyfile, password) - self.params.certfile = certfile - self.params.keyfile = keyfile - self.params.password = password self.params.server_hostname = server_hostname - self.new_transport.setup_tls( - False, host, port, sslctx, certfile, keyfile, password, server_hostname - ) async def connect(self) -> bool: """Initiate connection to start client.""" # if reconnect_delay_current was set to 0 by close(), we need to set it back again # so this instance will work - self.new_transport.reset_delay() + self.reset_delay() # force reconnect if required: Log.debug( "Connecting to {}:{}.", - self.new_transport.comm_params.host, - self.new_transport.comm_params.port, + self.comm_params.host, + self.comm_params.port, ) - return await self.new_transport.transport_connect() + return await self.transport_connect() class ModbusTlsClient(ModbusTcpClient): @@ -145,7 +118,7 @@ def __init__( host: str, port: int = 802, framer: Type[ModbusFramer] = ModbusTlsFramer, - sslctx: str = None, + sslctx: ssl.SSLContext = None, certfile: str = None, keyfile: str = None, password: str = None, @@ -155,11 +128,9 @@ def __init__( """Initialize Modbus TLS Client.""" self.transport = None super().__init__(host, port=port, framer=framer, **kwargs) - self.sslctx = sslctx_provider(sslctx, certfile, keyfile, password) - self.params.sslctx = sslctx - self.params.certfile = certfile - self.params.keyfile = keyfile - self.params.password = password + self.sslctx = CommParams.generate_ssl( + False, certfile, keyfile, password, sslctx=sslctx + ) self.params.server_hostname = server_hostname @property diff --git a/pymodbus/client/udp.py b/pymodbus/client/udp.py index 77d744b13..118146747 100644 --- a/pymodbus/client/udp.py +++ b/pymodbus/client/udp.py @@ -8,6 +8,7 @@ from pymodbus.framer import ModbusFramer from pymodbus.framer.socket_framer import ModbusSocketFramer from pymodbus.logging import Log +from pymodbus.transport.transport import CommType DGRAM_TYPE = socket.SOCK_DGRAM @@ -50,15 +51,16 @@ def __init__( """Initialize Asyncio Modbus UDP Client.""" asyncio.DatagramProtocol.__init__(self) asyncio.Protocol.__init__(self) - ModbusBaseClient.__init__(self, framer=framer, **kwargs) + ModbusBaseClient.__init__( + self, framer=framer, CommType=CommType.UDP, host=host, port=port, **kwargs + ) self.params.port = port self.params.source_address = source_address - self.new_transport.setup_udp(False, host, port) @property def connected(self): """Return true if connected.""" - return self.new_transport.is_active() + return self.is_active() async def connect(self) -> bool: """Start reconnecting asynchronous udp client. @@ -67,15 +69,15 @@ async def connect(self) -> bool: """ # if reconnect_delay_current was set to 0 by close(), we need to set it back again # so this instance will work - self.new_transport.reset_delay() + self.reset_delay() # force reconnect if required: Log.debug( "Connecting to {}:{}.", - self.new_transport.comm_params.host, - self.new_transport.comm_params.port, + self.comm_params.host, + self.comm_params.port, ) - return await self.new_transport.transport_connect() + return await self.transport_connect() class ModbusUdpClient(ModbusBaseClient): diff --git a/pymodbus/server/__init__.py b/pymodbus/server/__init__.py index 70793baa8..937f111de 100644 --- a/pymodbus/server/__init__.py +++ b/pymodbus/server/__init__.py @@ -9,14 +9,12 @@ "ModbusTcpServer", "ModbusTlsServer", "ModbusUdpServer", - "ModbusUnixServer", "ServerAsyncStop", "ServerStop", "StartAsyncSerialServer", "StartAsyncTcpServer", "StartAsyncTlsServer", "StartAsyncUdpServer", - "StartAsyncUnixServer", "StartSerialServer", "StartTcpServer", "StartTlsServer", @@ -28,14 +26,12 @@ ModbusTcpServer, ModbusTlsServer, ModbusUdpServer, - ModbusUnixServer, ServerAsyncStop, ServerStop, StartAsyncSerialServer, StartAsyncTcpServer, StartAsyncTlsServer, StartAsyncUdpServer, - StartAsyncUnixServer, StartSerialServer, StartTcpServer, StartTlsServer, diff --git a/pymodbus/server/async_io.py b/pymodbus/server/async_io.py index 71c2ce71e..85b79882e 100644 --- a/pymodbus/server/async_io.py +++ b/pymodbus/server/async_io.py @@ -1,7 +1,6 @@ """Implementation of a Threaded Modbus Server.""" # pylint: disable=missing-type-doc import asyncio -import ssl import time import traceback from contextlib import suppress @@ -11,6 +10,7 @@ from pymodbus.device import ModbusControlBlock, ModbusDeviceIdentification from pymodbus.exceptions import NoSuchSlaveException from pymodbus.factory import ServerDecoder +from pymodbus.framer import ModbusFramer from pymodbus.logging import Log from pymodbus.pdu import ModbusExceptions as merror from pymodbus.transaction import ( @@ -20,51 +20,19 @@ ModbusTlsFramer, ) from pymodbus.transport.serial_asyncio import create_serial_connection +from pymodbus.transport.transport import CommParams, CommType, Transport with suppress(ImportError): import serial -def sslctx_provider( - sslctx=None, certfile=None, keyfile=None, password=None, reqclicert=False -): - """Provide the SSLContext for ModbusTlsServer. - - If the user defined SSLContext is not passed in, sslctx_provider will - produce a default one. - - :param sslctx: The user defined SSLContext to use for TLS (default None and - auto create) - :param certfile: The cert file path for TLS (used if sslctx is None) - :param keyfile: The key file path for TLS (used if sslctx is None) - :param password: The password for for decrypting the private key file - :param reqclicert: Force the sever request client's certificate - """ - if sslctx is None: - # According to MODBUS/TCP Security Protocol Specification, it is - # TLSv2 at least - sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - sslctx.verify_mode = ssl.CERT_NONE - sslctx.check_hostname = False - sslctx.options |= ssl.OP_NO_TLSv1_1 - sslctx.options |= ssl.OP_NO_TLSv1 - sslctx.options |= ssl.OP_NO_SSLv3 - sslctx.options |= ssl.OP_NO_SSLv2 - sslctx.load_cert_chain(certfile=certfile, keyfile=keyfile, password=password) - - if reqclicert: - sslctx.verify_mode = ssl.CERT_REQUIRED - - return sslctx - - # --------------------------------------------------------------------------- # # Protocol Handlers # --------------------------------------------------------------------------- # -class ModbusServerRequestHandler(asyncio.BaseProtocol): +class ModbusServerRequestHandler(Transport): """Implements modbus slave wire protocol. This uses the asyncio.Protocol to implement the server protocol. @@ -78,45 +46,50 @@ class ModbusServerRequestHandler(asyncio.BaseProtocol): def __init__(self, owner): """Initialize.""" + params = CommParams( + comm_name="server", + reconnect_delay=0.0, + reconnect_delay_max=0.0, + timeout_connect=0.0, + ) + super().__init__(params, True) self.server = owner self.running = False self.receive_queue = asyncio.Queue() self.handler_task = None # coroutine to be run on asyncio loop self._sent = b"" # for handle_local_echo self.client_address = (None, None) + self.framer: ModbusFramer = None def _log_exception(self): """Show log exception.""" Log.debug("Handler for stream [{}] has been canceled", self.client_address) - def connection_made(self, transport): - """Call for socket establish - - For streamed protocols (TCP) this will also correspond to an - entire conversation; however for datagram protocols (UDP) this - corresponds to the socket being opened - """ + def callback_connected(self) -> None: + """Call when connection is succcesfull.""" try: if ( - hasattr(transport, "get_extra_info") - and transport.get_extra_info("peername") is not None + hasattr(self.transport, "get_extra_info") + and self.transport.get_extra_info("peername") is not None ): - self.client_address = transport.get_extra_info("peername")[:2] + self.client_address = self.transport.get_extra_info("peername")[:2] Log.debug("Peer [{}] opened", self.client_address) - elif hasattr(transport, "serial"): - Log.debug("Serial connection opened on port: {}", transport.serial.port) + elif hasattr(self.transport, "serial"): + Log.debug( + "Serial connection opened on port: {}", self.transport.serial.port + ) self.client_address = ("serial", "server") else: - Log.warning("Unable to get information about transport {}", transport) - self.transport = transport # pylint: disable=attribute-defined-outside-init - self.running = True - self.framer = ( # pylint: disable=attribute-defined-outside-init - self.server.framer( - self.server.decoder, - client=None, + Log.warning( + "Unable to get information about transport {}", self.transport ) + self.transport = self.transport + self.running = True + self.framer = self.server.framer( + self.server.decoder, + client=None, ) - self.server.active_connections[self.client_address] = self + self.server.local_active_connections[self.client_address] = self # schedule the connection handler on the event loop self.handler_task = asyncio.create_task(self.handle()) @@ -127,18 +100,13 @@ def connection_made(self, transport): traceback.format_exc(), ) - def connection_lost(self, call_exc): - """Call for socket tear down. - - For streamed protocols any break in the network connection will - be reported here; for datagram protocols, only a teardown of the - socket itself will result in this call. - """ + def callback_disconnected(self, call_exc: Exception) -> None: + """Call when connection is lost.""" try: if self.handler_task: self.handler_task.cancel() - if self.client_address in self.server.active_connections: - self.server.active_connections.pop(self.client_address) + if self.client_address in self.server.local_active_connections: + self.server.local_active_connections.pop(self.client_address) if hasattr(self.server, "on_connection_lost"): self.server.on_connection_lost() if call_exc is None: @@ -273,7 +241,9 @@ def send(self, message, *addr, **kwargs): def __send(msg, *addr): Log.debug("send: [{}]- {}", message, msg, ":b2a") if addr == (None,): - self._send_(msg) + self.transport.write(msg) + if self.server.handle_local_echo is True: + self._sent = msg else: self.transport.sendto(msg, *addr) @@ -286,20 +256,6 @@ def __send(msg, *addr): else: Log.debug("Skipping sending response!!") - # ----------------------------------------------------------------------- # - # Derived class implementations - # ----------------------------------------------------------------------- # - - def _send_(self, data): # pragma: no cover - """Send a request (string) to the network. - - :param data: The unencoded modbus response - :raises NotImplementedException: - """ - self.transport.write(data) - if self.server.handle_local_echo is True: - self._sent = data - async def _recv_(self): # pragma: no cover """Receive data from the network.""" try: @@ -309,8 +265,8 @@ async def _recv_(self): # pragma: no cover result = None return result - def data_received(self, data): - """Call when some data is received.""" + def callback_data(self, data: bytes, addr: tuple = None) -> int: + """Handle received data.""" if self.server.handle_local_echo is True and self._sent: if self._sent in data: data, self._sent = data.replace(self._sent, b"", 1), b"" @@ -319,29 +275,12 @@ def data_received(self, data): else: self._sent = b"" if not data: - return - self.receive_queue.put_nowait(data) - - def datagram_received(self, data, addr): - """Call when a datagram is received. - - data is a bytes object containing the incoming data. addr - is the address of the peer sending the data; the exact - format depends on the transport. - """ - self.receive_queue.put_nowait((data, addr)) - - def error_received(self, exc): # pragma: no cover - """Call when a previous send/receive raises an OSError. - - exc is the OSError instance. - - This method is called in rare conditions, - when the transport (e.g. UDP) detects that a datagram could - not be delivered to its recipient. In many conditions - though, undeliverable datagrams will be silently dropped. - """ - Log.error("datagram connection error [{}]", exc) + return 0 + if addr: + self.receive_queue.put_nowait((data, addr)) + else: + self.receive_queue.put_nowait(data) + return len(data) # --------------------------------------------------------------------------- # @@ -349,105 +288,7 @@ def error_received(self, exc): # pragma: no cover # --------------------------------------------------------------------------- # -class ModbusUnixServer: - """A modbus threaded Unix socket server. - - We inherit and overload the socket server so that we - can control the client threads as well as have a single - server context instance. - """ - - def __init__( - self, - context, - path, - framer=None, - identity=None, - **kwargs, - ): - """Initialize the socket server. - - If the identify structure is not passed in, the ModbusControlBlock - uses its own default structure. - - :param context: The ModbusServerContext datastore - :param path: unix socket path - :param framer: The framer strategy to use - :param identity: An optional identify structure - :param allow_reuse_address: Whether the server will allow the - reuse of an address. - :param ignore_missing_slaves: True to not send errors on a request - to a missing slave - :param broadcast_enable: True to treat slave_id 0 as broadcast address, - False to treat 0 as any other slave_id - :param response_manipulator: Callback method for manipulating the - response - """ - self.active_connections = {} - self.loop = kwargs.get("loop") or asyncio.get_event_loop() - self.decoder = ServerDecoder() - self.framer = framer or ModbusSocketFramer - self.context = context or ModbusServerContext() - self.control = ModbusControlBlock() - self.path = path - self.handler = ModbusServerRequestHandler - self.handler.server = self - self.ignore_missing_slaves = kwargs.get("ignore_missing_slaves", False) - self.broadcast_enable = kwargs.get("broadcast_enable", False) - self.response_manipulator = kwargs.get("response_manipulator", None) - if isinstance(identity, ModbusDeviceIdentification): - self.control.Identity.update(identity) - - # asyncio future that will be done once server has started - self.serving = asyncio.Future() - self.serving_done = asyncio.Future() - # constructors cannot be declared async, so we have to - # defer the initialization of the server - self.server = None - self.request_tracer = None - self.factory_parms = {} - self.handle_local_echo = False - - async def serve_forever(self): - """Start endless loop.""" - if self.server is None: - try: - self.server = await self.loop.create_unix_server( - lambda: self.handler(self), - self.path, - ) - self.serving.set_result(True) - Log.info("Server(Unix) listening.") - await self.server.serve_forever() - except asyncio.exceptions.CancelledError: - self.serving_done.set_result(True) - raise - except Exception as exc: # pylint: disable=broad-except - Log.error("Server unexpected exception {}", exc) - else: - raise RuntimeError( - "Can't call serve_forever on an already running server object" - ) - self.serving_done.set_result(True) - Log.info("Server graceful shutdown.") - - async def shutdown(self): - """Shutdown server.""" - await self.server_close() - - async def server_close(self): - """Close server.""" - for k_item, v_item in self.active_connections.items(): - Log.warning("aborting active session {}", k_item) - v_item.handler_task.cancel() - self.active_connections = {} - if self.server is not None: - self.server.close() - await self.server.wait_closed() - self.server = None - - -class ModbusTcpServer: +class ModbusTcpServer(Transport): """A modbus threaded tcp socket server. We inherit and overload the socket server so that we @@ -461,8 +302,6 @@ def __init__( framer=None, identity=None, address=None, - allow_reuse_address=False, - backlog=20, **kwargs, ): """Initialize the socket server. @@ -474,11 +313,6 @@ def __init__( :param framer: The framer strategy to use :param identity: An optional identify structure :param address: An optional (interface, port) to bind to. - :param allow_reuse_address: Whether the server will allow the - reuse of an address. - :param backlog: is the maximum number of queued connections - passed to listen(). increase if many - connections are being made and broken to your Modbus slave :param ignore_missing_slaves: True to not send errors on a request to a missing slave :param broadcast_enable: True to treat slave_id 0 as broadcast address, @@ -486,16 +320,31 @@ def __init__( :param response_manipulator: Callback method for manipulating the response """ - self.active_connections = {} + if not address: + address = ("", 502) + params = kwargs.get( + "internal_tls_setup", + CommParams( + comm_type=CommType.TCP, + comm_name="server_listener", + host=address[0], + port=address[1], + reconnect_delay=0.0, + reconnect_delay_max=0.0, + timeout_connect=0.0, + ), + ) + super().__init__( + params, + True, + ) + self.local_active_connections = {} self.loop = kwargs.get("loop") or asyncio.get_event_loop() - self.allow_reuse_address = allow_reuse_address self.decoder = ServerDecoder() self.framer = framer or ModbusSocketFramer self.context = context or ModbusServerContext() self.control = ModbusControlBlock() - self.address = address or ("", 502) - self.handler = ModbusServerRequestHandler - self.handler.server = self + self.address = address self.ignore_missing_slaves = kwargs.get("ignore_missing_slaves", False) self.broadcast_enable = kwargs.get("broadcast_enable", False) self.response_manipulator = kwargs.get("response_manipulator", None) @@ -510,17 +359,26 @@ def __init__( # defer the initialization of the server self.server = None self.factory_parms = { - "reuse_address": allow_reuse_address, - "backlog": backlog, + "reuse_address": True, "start_serving": True, } + if params.sslctx: + self.factory_parms["ssl"] = params.sslctx self.handle_local_echo = False + def handle_new_connection(self): + """Handle incoming connect.""" + handler = ModbusServerRequestHandler + handler.server = self + return handler + async def serve_forever(self): """Start endless loop.""" if self.server is None: + handler = ModbusServerRequestHandler + handler.server = self self.server = await self.loop.create_server( - lambda: self.handler(self), + lambda: handler(self), *self.address, **self.factory_parms, ) @@ -546,14 +404,14 @@ async def shutdown(self): async def server_close(self): """Close server.""" - active_connecions = self.active_connections.copy() + active_connecions = self.local_active_connections.copy() for k_item, v_item in active_connecions.items(): Log.warning("aborting active session {}", k_item) v_item.transport.close() await asyncio.sleep(0.1) v_item.handler_task.cancel() await v_item.handler_task - self.active_connections = {} + self.local_active_connections = {} if self.server is not None: self.server.close() await self.server.wait_closed() @@ -568,7 +426,7 @@ class ModbusTlsServer(ModbusTcpServer): server context instance. """ - def __init__( # pylint: disable=too-many-arguments + def __init__( self, context, framer=None, @@ -578,9 +436,6 @@ def __init__( # pylint: disable=too-many-arguments certfile=None, keyfile=None, password=None, - reqclicert=False, - allow_reuse_address=False, - backlog=20, **kwargs, ): """Overloaded initializer for the socket server. @@ -597,12 +452,6 @@ def __init__( # pylint: disable=too-many-arguments :param certfile: The cert file path for TLS (used if sslctx is None) :param keyfile: The key file path for TLS (used if sslctx is None) :param password: The password for for decrypting the private key file - :param reqclicert: Force the sever request client's certificate - :param allow_reuse_address: Whether the server will allow the - reuse of an address. - :param backlog: is the maximum number of queued connections - passed to listen(). increase if many - connections are being made and broken to your Modbus slave :param ignore_missing_slaves: True to not send errors on a request to a missing slave :param broadcast_enable: True to treat slave_id 0 as broadcast address, @@ -615,16 +464,21 @@ def __init__( # pylint: disable=too-many-arguments framer=framer, identity=identity, address=address, - allow_reuse_address=allow_reuse_address, - backlog=backlog, + internal_tls_setup=CommParams( + comm_type=CommType.TLS, + comm_name="server_listener", + reconnect_delay=0.0, + reconnect_delay_max=0.0, + timeout_connect=0.0, + sslctx=CommParams.generate_ssl( + True, certfile, keyfile, password, sslctx=sslctx + ), + ), **kwargs, ) - self.sslctx = sslctx_provider(sslctx, certfile, keyfile, password, reqclicert) - self.factory_parms["ssl"] = self.sslctx - self.handle_local_echo = False -class ModbusUdpServer: +class ModbusUdpServer(Transport): """A modbus threaded udp socket server. We inherit and overload the socket server so that we @@ -638,7 +492,6 @@ def __init__( framer=None, identity=None, address=None, - backlog=20, **kwargs, ): """Overloaded initializer for the socket server. @@ -659,17 +512,25 @@ def __init__( :param response_manipulator: Callback method for manipulating the response """ - # TO BE REMOVED: - self.backlog = backlog # ---------------- - self.active_connections = {} + super().__init__( + CommParams( + comm_type=CommType.UDP, + comm_name="server_listener", + reconnect_delay=0.0, + reconnect_delay_max=0.0, + timeout_connect=0.0, + ), + True, + ) + + self.local_active_connections = {} self.loop = asyncio.get_running_loop() self.decoder = ServerDecoder() self.framer = framer or ModbusSocketFramer self.context = context or ModbusServerContext() self.control = ModbusControlBlock() self.address = address or ("", 502) - self.handler = ModbusServerRequestHandler self.ignore_missing_slaves = kwargs.get("ignore_missing_slaves", False) self.broadcast_enable = kwargs.get("broadcast_enable", False) self.response_manipulator = kwargs.get("response_manipulator", None) @@ -694,8 +555,10 @@ async def serve_forever(self): """Start endless loop.""" if self.protocol is None: try: + handler = ModbusServerRequestHandler + handler.server = self self.protocol, self.endpoint = await self.loop.create_datagram_endpoint( - lambda: self.handler(self), + lambda: handler(self), **self.factory_parms, ) except asyncio.exceptions.CancelledError: @@ -732,7 +595,7 @@ async def server_close(self): self.protocol = None -class ModbusSerialServer: # pylint: disable=too-many-instance-attributes +class ModbusSerialServer(Transport): # pylint: disable=too-many-instance-attributes """A modbus threaded serial socket server. We inherit and overload the socket server so that we @@ -769,6 +632,17 @@ def __init__( :param response_manipulator: Callback method for manipulating the response """ + super().__init__( + CommParams( + comm_type=CommType.SERIAL, + comm_name="server_listener", + reconnect_delay=0.0, + reconnect_delay_max=0.0, + timeout_connect=0.0, + ), + True, + ) + self.loop = kwargs.get("loop") or asyncio.get_event_loop() self.bytesize = kwargs.get("bytesize", 8) self.parity = kwargs.get("parity", "N") @@ -782,7 +656,6 @@ def __init__( self.auto_reconnect = kwargs.get("auto_reconnect", False) self.reconnect_delay = kwargs.get("reconnect_delay", 2) self.reconnecting_task = None - self.handler = kwargs.get("handler") or ModbusServerRequestHandler self.framer = framer or ModbusRtuFramer self.decoder = ServerDecoder() self.context = context or ModbusServerContext() @@ -790,7 +663,7 @@ def __init__( self.control = ModbusControlBlock() if isinstance(identity, ModbusDeviceIdentification): self.control.Identity.update(identity) - self.active_connections = {} + self.local_active_connections = {} self.request_tracer = None self.protocol = None self.transport = None @@ -818,9 +691,11 @@ async def _connect(self): if self.device.startswith("socket:"): return try: + handler = ModbusServerRequestHandler + handler.server = self self.transport, self.protocol = await create_serial_connection( self.loop, - lambda: self.handler(self), + lambda: handler(self), self.device, baudrate=self.baudrate, bytesize=self.bytesize, @@ -850,15 +725,15 @@ async def shutdown(self): if self.transport: self.transport.abort() self.transport = None - loop_list = list(self.active_connections) + loop_list = list(self.local_active_connections) for k_item in loop_list: - v_item = self.active_connections[k_item] + v_item = self.local_active_connections[k_item] Log.warning("aborting active session {}", k_item) v_item.transport.close() await asyncio.sleep(0.1) v_item.handler_task.cancel() await v_item.handler_task - self.active_connections = {} + self.local_active_connections = {} if self.server: self.server.close() await asyncio.wait_for(self.server.wait_closed(), 10) @@ -889,12 +764,13 @@ async def serve_forever(self): # Socket server means listen so start a socket server parts = self.device[9:].split(":") host_addr = (parts[0], int(parts[1])) + handler = ModbusServerRequestHandler + handler.server = self self.server = await self.loop.create_server( - lambda: self.handler(self), + lambda: handler(self), *host_addr, reuse_address=True, start_serving=True, - backlog=20, ) try: await self.server.serve_forever() @@ -919,9 +795,7 @@ class _serverList: :meta private: """ - active_server: Union[ - ModbusUnixServer, ModbusTcpServer, ModbusUdpServer, ModbusSerialServer - ] = None + active_server: Union[ModbusTcpServer, ModbusUdpServer, ModbusSerialServer] = None def __init__(self, server): """Register new server.""" @@ -957,28 +831,6 @@ def stop(cls): time.sleep(10) -async def StartAsyncUnixServer( # pylint: disable=invalid-name,dangerous-default-value - context=None, - identity=None, - path=None, - custom_functions=[], - **kwargs, -): - """Start and run a tcp modbus server. - - :param context: The ModbusServerContext datastore - :param identity: An optional identify structure - :param path: An optional path to bind to. - :param custom_functions: An optional list of custom function classes - supported by server instance. - :param kwargs: The rest - """ - server = ModbusUnixServer( - context, path, kwargs.pop("framer", ModbusSocketFramer), identity, **kwargs - ) - await _serverList.run(server, custom_functions) - - async def StartAsyncTcpServer( # pylint: disable=invalid-name,dangerous-default-value context=None, identity=None, @@ -1009,8 +861,6 @@ async def StartAsyncTlsServer( # pylint: disable=invalid-name,dangerous-default certfile=None, keyfile=None, password=None, - reqclicert=False, - allow_reuse_address=False, custom_functions=[], **kwargs, ): @@ -1023,9 +873,6 @@ async def StartAsyncTlsServer( # pylint: disable=invalid-name,dangerous-default :param certfile: The cert file path for TLS (used if sslctx is None) :param keyfile: The key file path for TLS (used if sslctx is None) :param password: The password for for decrypting the private key file - :param reqclicert: Force the sever request client's certificate - :param allow_reuse_address: Whether the server will allow the reuse of an - address. :param custom_functions: An optional list of custom function classes supported by server instance. :param kwargs: The rest @@ -1039,8 +886,6 @@ async def StartAsyncTlsServer( # pylint: disable=invalid-name,dangerous-default certfile, keyfile, password, - reqclicert, - allow_reuse_address=allow_reuse_address, **kwargs, ) await _serverList.run(server, custom_functions) diff --git a/pymodbus/server/reactive/default_config.json b/pymodbus/server/reactive/default_config.json index 4fb88bf82..fea2e1fe4 100644 --- a/pymodbus/server/reactive/default_config.json +++ b/pymodbus/server/reactive/default_config.json @@ -1,8 +1,6 @@ { "tcp": { "handler": "ModbusConnectedRequestHandler", - "allow_reuse_address": true, - "backlog": 20, "ignore_missing_slaves": false }, "serial": { @@ -19,8 +17,6 @@ "handler": "ModbusConnectedRequestHandler", "certfile": null, "keyfile": null, - "allow_reuse_address": true, - "backlog": 20, "ignore_missing_slaves": false }, "udp": { diff --git a/pymodbus/server/reactive/default_config.py b/pymodbus/server/reactive/default_config.py index 0c861757a..bc7fdbb77 100644 --- a/pymodbus/server/reactive/default_config.py +++ b/pymodbus/server/reactive/default_config.py @@ -3,9 +3,7 @@ DEFAULT_CONFIG = { "tcp": { "handler": "ModbusConnectedRequestHandler", - "allow_reuse_address": True, "allow_reuse_port": True, - "backlog": 20, "ignore_missing_slaves": False, }, "serial": { @@ -22,9 +20,7 @@ "handler": "ModbusConnectedRequestHandler", "certfile": None, "keyfile": None, - "allow_reuse_address": True, "allow_reuse_port": True, - "backlog": 20, "ignore_missing_slaves": False, }, "udp": { diff --git a/pymodbus/server/simulator/setup.json b/pymodbus/server/simulator/setup.json index 74ae087bc..2f605d21b 100644 --- a/pymodbus/server/simulator/setup.json +++ b/pymodbus/server/simulator/setup.json @@ -4,7 +4,6 @@ "comm": "tcp", "host": "0.0.0.0", "port": 5020, - "allow_reuse_address": true, "ignore_missing_slaves": false, "framer": "socket", "identity": { @@ -42,8 +41,6 @@ "port": 5020, "certfile": "certificates/pymodbus.crt", "keyfile": "certificates/pymodbus.key", - "allow_reuse_address": true, - "backlog": 20, "ignore_missing_slaves": false, "framer": "tls", "identity": { diff --git a/pymodbus/transport/nullmodem.py b/pymodbus/transport/nullmodem.py deleted file mode 100644 index ea8685736..000000000 --- a/pymodbus/transport/nullmodem.py +++ /dev/null @@ -1,115 +0,0 @@ -"""Null modem transport. - -This is a special transport, mostly thought of for testing. - -NullModem interconnect 2 transport objects and transfers calls: - - server.listen() - - dummy - - client.connect() - - call client.connection_made() - - call server.connection_made() - - client/server.close() - - call client.connection_lost() - - call server.connection_lost() - - server/client.send - - call client/server.data_received() - -""" -from __future__ import annotations - -import asyncio - -from pymodbus.logging import Log -from pymodbus.transport.transport import Transport - - -class DummyTransport(asyncio.BaseTransport): - """Use in connection_made calls.""" - - def close(self): - """Define dummy.""" - - def get_protocol(self): - """Define dummy.""" - - def is_closing(self): - """Define dummy.""" - - def set_protocol(self, _protocol): - """Define dummy.""" - - def abort(self): - """Define dummy.""" - - -class NullModem(Transport): - """Transport layer. - - Contains methods to act as a null modem between 2 objects. - (Allowing tests to be shortcut without actual network calls) - """ - - nullmodem_client: NullModem = None - nullmodem_server: NullModem = None - - def __init__(self, *arg): - """Overwrite init.""" - self.is_server: bool = False - self.other_end: NullModem = None - super().__init__(*arg) - - async def transport_connect(self) -> bool: - """Handle generic connect and call on to specific transport connect.""" - Log.debug("NullModem: Simulate connect on {}", self.comm_params.comm_name) - if not self.loop: - self.loop = asyncio.get_running_loop() - if self.nullmodem_server: - self.__class__.nullmodem_client = self - self.other_end = self.nullmodem_server - self.nullmodem_server.other_end = self - self.cb_connection_made() - self.other_end.cb_connection_made() - return True - return False - - async def transport_listen(self): - """Handle generic listen and call on to specific transport listen.""" - Log.debug("NullModem: Simulate listen on {}", self.comm_params.comm_name) - if not self.loop: - self.loop = asyncio.get_running_loop() - self.is_server = True - self.__class__.nullmodem_server = self - return DummyTransport() - - # -------------------------------- # - # Helper methods for child classes # - # -------------------------------- # - async def send(self, data: bytes) -> bool: - """Send request. - - :param data: non-empty bytes object with data to send. - """ - Log.debug("NullModem: simulate send {}", data, ":hex") - self.other_end.data_received(data) - return True - - def close(self, reconnect: bool = False) -> None: - """Close connection. - - :param reconnect: (default false), try to reconnect - """ - self.recv_buffer = b"" - if not reconnect: - if self.nullmodem_client: - self.nullmodem_client.cb_connection_lost(None) - if self.nullmodem_server: - self.nullmodem_server.cb_connection_lost(None) - self.__class__.nullmodem_client = None - self.__class__.nullmodem_server = None - - # ----------------- # - # The magic methods # - # ----------------- # - def __str__(self) -> str: - """Build a string representation of the connection.""" - return f"{self.__class__.__name__}({self.comm_params.comm_name})" diff --git a/pymodbus/transport/transport.py b/pymodbus/transport/transport.py index 89751fd06..e8ecc411c 100644 --- a/pymodbus/transport/transport.py +++ b/pymodbus/transport/transport.py @@ -4,242 +4,130 @@ from __future__ import annotations import asyncio +import dataclasses import ssl -import sys -from dataclasses import dataclass +from enum import Enum from typing import Any, Callable, Coroutine from pymodbus.logging import Log from pymodbus.transport.serial_asyncio import create_serial_connection -class Transport: - """Transport layer. +NULLMODEM_HOST = "__pymodbus_nullmodem" - Contains pure transport methods needed to connect/listen, send/receive and close connections - for unix socket, tcp, tls and serial communications. - Contains high level methods like reconnect. +class CommType(Enum): + """Type of transport""" - This class is not available in the pymodbus API, and should not be referenced in Applications - nor in the pymodbus documentation. + TCP = 1 + TLS = 2 + UDP = 3 + SERIAL = 4 - The class is designed to be an object in the message level class. - """ - @dataclass - class CommParamsClass: - """Parameter class.""" +@dataclasses.dataclass +class CommParams: + """Parameter class.""" + + # generic + comm_name: str = None + comm_type: CommType = None + reconnect_delay: float = None + reconnect_delay_max: float = None + timeout_connect: float = None + + # tcp / tls / udp / serial + host: str = None - # generic - done: bool = False - comm_name: str = None - reconnect_delay: float = None - reconnect_delay_max: float = None - timeout_connect: float = None + # tcp / tls / udp + port: int = None + + # tls + sslctx: ssl.SSLContext = None + + # serial + baudrate: int = None + bytesize: int = None + parity: str = None + stopbits: int = None + + @classmethod + def generate_ssl( + cls, + is_server: bool, + certfile: str = None, + keyfile: str = None, + password: str = None, + sslctx: ssl.SSLContext = None, + ) -> ssl.SSLContext: + """Generate sslctx from cert/key/passwor + + MODBUS/TCP Security Protocol Specification demands TLSv2 at least + """ + if sslctx: + return sslctx + new_sslctx = ssl.SSLContext( + ssl.PROTOCOL_TLS_SERVER if is_server else ssl.PROTOCOL_TLS_CLIENT + ) + new_sslctx.check_hostname = False + new_sslctx.verify_mode = ssl.CERT_NONE + new_sslctx.options |= ssl.OP_NO_TLSv1_1 + new_sslctx.options |= ssl.OP_NO_TLSv1 + new_sslctx.options |= ssl.OP_NO_SSLv3 + new_sslctx.options |= ssl.OP_NO_SSLv2 + if certfile: + new_sslctx.load_cert_chain( + certfile=certfile, keyfile=keyfile, password=password + ) + return new_sslctx - # tcp / tls / udp / serial - host: str = None - # tcp / tls / udp - port: int = None +class Transport(asyncio.BaseProtocol): + """Protocol layer including transport. - # tls - ssl: ssl.SSLContext = None - server_hostname: str = None + Contains pure transport methods needed to connect/listen, send/receive and close connections + for unix socket, tcp, tls and serial communications. - # serial - baudrate: int = None - bytesize: int = None - parity: str = None - stopbits: int = None + Contains high level methods like reconnect. - def check_done(self): - """Check if already setup""" - if self.done: - raise RuntimeError("Already setup!") - self.done = True + The class is designed to take care of differences between the different transport mediums, and + provide a neutral interface for the upper layers. + """ def __init__( self, - comm_name: str, - reconnect_delay: int, - reconnect_max: int, - timeout_connect: int, - callback_connected: Callable[[], None], - callback_disconnected: Callable[[Exception], None], - callback_data: Callable[[bytes], int], + params: CommParams, + is_server: bool, ) -> None: """Initialize a transport instance. - :param comm_name: name of this transport connection - :param reconnect_delay: delay in milliseconds for first reconnect (0 for no reconnect) - :param reconnect_delay: max reconnect delay in milliseconds - :param timeout_connect: Max. time in milliseconds for connect to complete + :param params: parameter dataclass + :param is_server: true if object act as a server (listen allowed) :param callback_connected: Called when connection is established :param callback_disconnected: Called when connection is disconnected :param callback_data: Called when data is received """ - self.cb_connection_made = callback_connected - self.cb_connection_lost = callback_disconnected - self.cb_handle_data = callback_data - - # properties, can be read, but may not be mingled with - self.comm_params = self.CommParamsClass( - comm_name=comm_name, - reconnect_delay=reconnect_delay / 1000, - reconnect_delay_max=reconnect_max / 1000, - timeout_connect=timeout_connect / 1000, - ) + self.comm_params = dataclasses.replace(params) + self.is_server = is_server self.reconnect_delay_current: float = 0.0 + self.listener: Transport = None self.transport: asyncio.BaseTransport | asyncio.Server = None self.loop: asyncio.AbstractEventLoop = None self.reconnect_task: asyncio.Task = None self.recv_buffer: bytes = b"" - self.call_connect_listen: Callable[[], Coroutine[Any, Any, Any]] = lambda: None - self.use_udp = False - - # ------------------------ # - # Transport specific setup # - # ------------------------ # - def setup_unix(self, setup_server: bool, host: str): - """Prepare transport unix""" - if sys.platform.startswith("win"): - raise RuntimeError("Modbus_unix is not supported on Windows!") - self.comm_params.check_done() - self.comm_params.done = True - self.comm_params.host = host - if setup_server: - self.call_connect_listen = lambda: self.loop.create_unix_server( - self.handle_listen, - path=self.comm_params.host, - start_serving=True, - ) - else: - self.call_connect_listen = lambda: self.loop.create_unix_connection( - lambda: self, - path=self.comm_params.host, - ) - - def setup_tcp(self, setup_server: bool, host: str, port: int): - """Prepare transport tcp""" - self.comm_params.check_done() - self.comm_params.done = True - self.comm_params.host = host - self.comm_params.port = port - if setup_server: - self.call_connect_listen = lambda: self.loop.create_server( - self.handle_listen, - host=self.comm_params.host, - port=self.comm_params.port, - reuse_address=True, - start_serving=True, - ) - else: - self.call_connect_listen = lambda: self.loop.create_connection( - lambda: self, - host=self.comm_params.host, - port=self.comm_params.port, - ) - - def setup_tls( - self, - setup_server: bool, - host: str, - port: int, - sslctx: ssl.SSLContext, - certfile: str, - keyfile: str, - password: str, - server_hostname: str, - ): - """Prepare transport tls""" - self.comm_params.check_done() - self.comm_params.done = True - self.comm_params.host = host - self.comm_params.port = port - self.comm_params.server_hostname = server_hostname - if not sslctx: - # According to MODBUS/TCP Security Protocol Specification, it is - # TLSv2 at least - sslctx = ssl.SSLContext( - ssl.PROTOCOL_TLS_SERVER if setup_server else ssl.PROTOCOL_TLS_CLIENT - ) - sslctx.check_hostname = False - sslctx.verify_mode = ssl.CERT_NONE - sslctx.options |= ssl.OP_NO_TLSv1_1 - sslctx.options |= ssl.OP_NO_TLSv1 - sslctx.options |= ssl.OP_NO_SSLv3 - sslctx.options |= ssl.OP_NO_SSLv2 - if certfile: - sslctx.load_cert_chain( - certfile=certfile, keyfile=keyfile, password=password - ) - self.comm_params.ssl = sslctx - if setup_server: - self.call_connect_listen = lambda: self.loop.create_server( - self.handle_listen, - host=self.comm_params.host, - port=self.comm_params.port, - reuse_address=True, - ssl=self.comm_params.ssl, - start_serving=True, - ) - else: - self.call_connect_listen = lambda: self.loop.create_connection( - lambda: self, - self.comm_params.host, - self.comm_params.port, - ssl=self.comm_params.ssl, - server_hostname=self.comm_params.server_hostname, - ) - - def setup_udp(self, setup_server: bool, host: str, port: int): - """Prepare transport udp""" - self.comm_params.check_done() - self.comm_params.done = True - self.comm_params.host = host - self.comm_params.port = port - if setup_server: - - async def call_async_listen(self): - """Remove protocol return value.""" - transport, _protocol = await self.loop.create_datagram_endpoint( - self.handle_listen, - local_addr=(self.comm_params.host, self.comm_params.port), - ) - return transport - - self.call_connect_listen = lambda: call_async_listen(self) - else: - self.call_connect_listen = lambda: self.loop.create_datagram_endpoint( - lambda: self, - remote_addr=(self.comm_params.host, self.comm_params.port), - ) - self.use_udp = True - - def setup_serial( - self, - setup_server: bool, - host: str, - baudrate: int, - bytesize: int, - parity: str, - stopbits: int, - ): - """Prepare transport serial""" - self.comm_params.check_done() - self.comm_params.done = True - self.comm_params.host = host - self.comm_params.baudrate = baudrate - self.comm_params.bytesize = bytesize - self.comm_params.parity = parity - self.comm_params.stopbits = stopbits - if setup_server: - self.call_connect_listen = lambda: create_serial_connection( + self.call_create: Callable[[], Coroutine[Any, Any, Any]] = lambda: None + self.active_connections: dict[str, Transport] = {} + self.unique_id: str = str(id(self)) + + # Transport specific setup + if params.host == NULLMODEM_HOST: + self.call_create = self.create_nullmodem + return + if params.comm_type == CommType.SERIAL: + self.call_create = lambda: create_serial_connection( self.loop, - self.handle_listen, + self.handle_new_connection, self.comm_params.host, baudrate=self.comm_params.baudrate, bytesize=self.comm_params.bytesize, @@ -247,17 +135,35 @@ def setup_serial( stopbits=self.comm_params.stopbits, timeout=self.comm_params.timeout_connect, ) - + return + if params.comm_type == CommType.UDP: + if is_server: + self.call_create = lambda: self.loop.create_datagram_endpoint( + self.handle_new_connection, + local_addr=(self.comm_params.host, self.comm_params.port), + ) + else: + self.call_create = lambda: self.loop.create_datagram_endpoint( + self.handle_new_connection, + remote_addr=(self.comm_params.host, self.comm_params.port), + ) + return + # TLS and TCP + if is_server: + self.call_create = lambda: self.loop.create_server( + self.handle_new_connection, + self.comm_params.host, + self.comm_params.port, + ssl=self.comm_params.sslctx, + reuse_address=True, + start_serving=True, + ) else: - self.call_connect_listen = lambda: create_serial_connection( - self.loop, - lambda: self, + self.call_create = lambda: self.loop.create_connection( + self.handle_new_connection, self.comm_params.host, - baudrate=self.comm_params.baudrate, - bytesize=self.comm_params.bytesize, - stopbits=self.comm_params.stopbits, - parity=self.comm_params.parity, - timeout=self.comm_params.timeout_connect, + self.comm_params.port, + ssl=self.comm_params.sslctx, ) async def transport_connect(self) -> bool: @@ -265,10 +171,9 @@ async def transport_connect(self) -> bool: Log.debug("Connecting {}", self.comm_params.comm_name) if not self.loop: self.loop = asyncio.get_running_loop() - self.transport = None try: self.transport, _protocol = await asyncio.wait_for( - self.call_connect_listen(), + self.call_create(), timeout=self.comm_params.timeout_connect, ) except ( @@ -276,19 +181,22 @@ async def transport_connect(self) -> bool: OSError, ) as exc: Log.warning("Failed to connect {}", exc) - self.close(reconnect=True) + self.transport_close(reconnect=True) return False return bool(self.transport) - async def transport_listen(self): + async def transport_listen(self) -> bool: """Handle generic listen and call on to specific transport listen.""" Log.debug("Awaiting connections {}", self.comm_params.comm_name) try: - self.transport = await self.call_connect_listen() + self.transport = await self.call_create() + if isinstance(self.transport, tuple): + self.transport = self.transport[0] except OSError as exc: Log.warning("Failed to start server {}", exc) - self.close() - return self.transport + self.transport_close() + return False + return True # ---------------------------------- # # Transport asyncio standard methods # @@ -299,22 +207,22 @@ def connection_made(self, transport: asyncio.BaseTransport): :param transport: socket etc. representing the connection. """ Log.debug("Connected to {}", self.comm_params.comm_name) - if not self.loop: - self.loop = asyncio.get_running_loop() self.transport = transport self.reset_delay() - self.cb_connection_made() + self.callback_connected() def connection_lost(self, reason: Exception): """Call from asyncio, when the connection is lost or closed. :param reason: None or an exception object """ + if not self.transport: + return Log.debug("Connection lost {} due to {}", self.comm_params.comm_name, reason) - self.cb_connection_lost(reason) - if self.transport: - self.close() - self.reconnect_task = asyncio.create_task(self.reconnect_connect()) + self.transport_close() + if not self.is_server: + self.reconnect_task = asyncio.create_task(self.do_reconnect()) + self.callback_disconnected(reason) def data_received(self, data: bytes): """Call when some data is received. @@ -323,42 +231,60 @@ def data_received(self, data: bytes): """ Log.debug("recv: {}", data, ":hex") self.recv_buffer += data - cut = self.cb_handle_data(self.recv_buffer) + cut = self.callback_data(self.recv_buffer) self.recv_buffer = self.recv_buffer[cut:] - def datagram_received(self, data, _addr): + def datagram_received(self, data: bytes, addr: tuple): """Receive datagram (UDP connections).""" - self.data_received(data) + Log.debug("recv: {} addr={}", data, ":hex", addr) + self.recv_buffer += data + cut = self.callback_data(self.recv_buffer, addr=addr) + self.recv_buffer = self.recv_buffer[cut:] def eof_received(self): - """Accept other end terminates connection. - - Actual handling are in connection_lost() - """ + """Accept other end terminates connection.""" Log.debug("-> eof_received") def error_received(self, exc): - """Get error detected in UDP. - - Actual handling are in connection_lost() - """ + """Get error detected in UDP.""" Log.debug("-> error_received {}", exc) raise RuntimeError(str(exc)) + # --------- # + # callbacks # + # --------- # + def callback_connected(self) -> None: + """Call when connection is succcesfull.""" + Log.debug("callback_connected called") + + def callback_disconnected(self, exc: Exception) -> None: + """Call when connection is lost.""" + Log.debug("callback_disconnected called: {}", exc) + + def callback_data(self, data: bytes, addr: tuple = None) -> int: + """Handle received data.""" + Log.debug("callback_data called: {} addr={}", data, ":hex", addr) + return 0 + # ----------------------------------- # # Helper methods for external classes # # ----------------------------------- # - async def send(self, data: bytes) -> bool: + def transport_send(self, data: bytes, addr: tuple = None) -> None: """Send request. :param data: non-empty bytes object with data to send. + :param addr: optional addr, only used for UDP server. """ Log.debug("send: {}", data, ":hex") - if self.use_udp: - return self.transport.sendto(data) # type: ignore[union-attr] - return self.transport.write(data) # type: ignore[union-attr] + if self.comm_params.comm_type == CommType.UDP: + if addr: + self.transport.sendto(data, addr=addr) # type: ignore[union-attr] + else: + self.transport.sendto(data) # type: ignore[union-attr] + else: + self.transport.write(data) # type: ignore[union-attr] - def close(self, reconnect: bool = False) -> None: + def transport_close(self, reconnect: bool = False) -> None: """Close connection. :param reconnect: (default false), try to reconnect @@ -371,7 +297,15 @@ def close(self, reconnect: bool = False) -> None: if not reconnect and self.reconnect_task: self.reconnect_task.cancel() self.reconnect_task = None - self.recv_buffer = b"" + self.reconnect_delay_current = 0.0 + self.recv_buffer = b"" + if self.listener: + self.listener.active_connections.pop(self.unique_id) + elif self.is_server: + for _key, value in self.active_connections.items(): + value.listener = None + value.transport_close() + self.active_connections = {} def reset_delay(self) -> None: """Reset wait time before next reconnect to minimal period.""" @@ -384,11 +318,26 @@ def is_active(self) -> bool: # ---------------- # # Internal methods # # ---------------- # - def handle_listen(self): + async def create_nullmodem(self): + """Bypass create_ and use null modem""" + new_transport = NullModem(self.is_server, self) + new_protocol = self.handle_new_connection() + new_protocol.connection_made(new_transport) + if self.is_server: + return new_transport, new_protocol + return new_transport, new_protocol + + def handle_new_connection(self): """Handle incoming connect.""" - return self + if not self.is_server: + return self + + new_transport = Transport(self.comm_params, True) + new_transport.listener = self + self.active_connections[new_transport.unique_id] = new_transport + return new_transport - async def reconnect_connect(self): + async def do_reconnect(self): """Handle reconnect as a task.""" try: self.reconnect_delay_current = self.comm_params.reconnect_delay @@ -418,8 +367,87 @@ async def __aenter__(self): async def __aexit__(self, _class, _value, _traceback) -> None: """Implement the client with async exit block.""" - self.close() + self.transport_close() def __str__(self) -> str: """Build a string representation of the connection.""" return f"{self.__class__.__name__}({self.comm_params.comm_name})" + + +class NullModem(asyncio.DatagramTransport, asyncio.WriteTransport): + """Transport layer. + + Contains methods to act as a null modem between 2 objects. + (Allowing tests to be shortcut without actual network calls) + """ + + client: NullModem = None + server: NullModem = None + client_protocol: Transport = None + server_protocol: Transport = None + + def __init__(self, is_server: bool, protocol: Transport): + """Create half part of null modem""" + asyncio.DatagramTransport.__init__(self) + asyncio.WriteTransport.__init__(self) + self.other_protocol: Transport = None + if is_server: + self.__class__.server = self + self.__class__.server_protocol = protocol + return + if not self.__class__.server: + raise RuntimeError("Connect called before listen") + self.__class__.client = self + self.__class__.client_protocol = protocol + self.client.other_protocol = self.server_protocol + self.server.other_protocol = self.client_protocol + + # ---------------- # + # external methods # + # ---------------- # + + def close(self): + """Close null modem""" + + def sendto(self, data: bytes, _addr: Any = None): + """Send datagrame""" + return self.write(data) + + def write(self, data: bytes): + """Send data""" + return len(data) + + # ---------------- # + # Abstract methods # + # ---------------- # + def abort(self) -> None: + """Abort connection.""" + + def can_write_eof(self) -> bool: + """Allow to write eof""" + return True + + def get_write_buffer_size(self) -> int: + """Set write limit.""" + return 1024 + + def get_write_buffer_limits(self) -> tuple[int, int]: + """Set flush limits""" + return (1, 1024) + + def set_write_buffer_limits(self, high: int = None, low: int = None) -> None: + """Set flush limits""" + + def write_eof(self) -> None: + """Write eof""" + + def get_protocol(self) -> Transport: + """Return current protocol.""" + return None + + def set_protocol(self, protocol: asyncio.BaseProtocol) -> None: + """Set current protocol.""" + + def is_closing(self) -> bool: + """Return true if closing""" + return False diff --git a/test/sub_examples/conftest.py b/test/sub_examples/conftest.py index 789840683..185b3c493 100644 --- a/test/sub_examples/conftest.py +++ b/test/sub_examples/conftest.py @@ -7,14 +7,21 @@ from pymodbus.server import ServerAsyncStop +@pytest_asyncio.fixture(name="port_offset") +def _define_port_offset(): + """Define port offset""" + return 0 + + @pytest_asyncio.fixture(name="mock_cmdline") def _define_commandline( use_comm, use_framer, use_port, + port_offset, ): """Define commandline.""" - my_port = str(use_port) + my_port = str(use_port + port_offset) cmdline = [ "--comm", use_comm, diff --git a/test/sub_examples/test_client_server_async.py b/test/sub_examples/test_client_server_async.py index d241aab11..e192310da 100755 --- a/test/sub_examples/test_client_server_async.py +++ b/test/sub_examples/test_client_server_async.py @@ -14,23 +14,26 @@ from examples.client_async import run_a_few_calls, run_async_client, setup_async_client +BASE_PORT = 6200 + + class TestClientServerAsyncExamples: """Test Client server async examples.""" USE_CASES = [ - ("tcp", "socket"), - ("tcp", "rtu"), - ("tls", "tls"), - ("udp", "socket"), - ("udp", "rtu"), - ("serial", "rtu"), - # awaiting fix: ("serial", "ascii"), - # awaiting fix: ("serial", "binary"), + ("tcp", "socket", BASE_PORT + 1), + ("tcp", "rtu", BASE_PORT + 2), + ("tls", "tls", BASE_PORT + 3), + ("udp", "socket", BASE_PORT + 4), + ("udp", "rtu", BASE_PORT + 5), + ("serial", "rtu", BASE_PORT + 6), + # awaiting fix: ("serial", "ascii", BASE_PORT + 7), + # awaiting fix: ("serial", "binary", BASE_PORT + 8), ] - @pytest.mark.xdist_group(name="server_serialize") + @pytest.mark.parametrize("port_offset", [0]) @pytest.mark.parametrize( - ("use_comm", "use_framer"), + ("use_comm", "use_framer", "use_port"), USE_CASES, ) async def test_combinations(self, mock_server): @@ -39,18 +42,18 @@ async def test_combinations(self, mock_server): test_client = setup_async_client(cmdline=cmdline) await run_async_client(test_client, modbus_calls=run_a_few_calls) - @pytest.mark.xdist_group(name="server_serialize") + @pytest.mark.parametrize("port_offset", [10]) @pytest.mark.parametrize( - ("use_comm", "use_framer"), + ("use_comm", "use_framer", "use_port"), USE_CASES, ) async def test_server_no_client(self, mock_server): """Run async server without client.""" assert mock_server - @pytest.mark.xdist_group(name="server_serialize") + @pytest.mark.parametrize("port_offset", [20]) @pytest.mark.parametrize( - ("use_comm", "use_framer"), + ("use_comm", "use_framer", "use_port"), USE_CASES, ) async def test_server_client_twice(self, mock_server): @@ -61,9 +64,9 @@ async def test_server_client_twice(self, mock_server): await asyncio.sleep(0.5) await run_async_client(test_client, modbus_calls=run_a_few_calls) - @pytest.mark.xdist_group(name="server_serialize") + @pytest.mark.parametrize("port_offset", [30]) @pytest.mark.parametrize( - ("use_comm", "use_framer"), + ("use_comm", "use_framer", "use_port"), USE_CASES, ) async def test_client_no_server(self, mock_cmdline): diff --git a/test/sub_examples/test_client_server_sync.py b/test/sub_examples/test_client_server_sync.py index 61f278552..8d133f94f 100755 --- a/test/sub_examples/test_client_server_sync.py +++ b/test/sub_examples/test_client_server_sync.py @@ -19,23 +19,26 @@ from pymodbus.server import ServerStop +BASE_PORT = 6300 + + class TestClientServerSyncExamples: """Test Client server async combinations.""" USE_CASES = [ - ("tcp", "socket"), - ("tcp", "rtu"), - # awaiting fix: ("tls", "tls"), - ("udp", "socket"), - ("udp", "rtu"), - ("serial", "rtu"), - # awaiting fix: ("serial", "ascii"), - # awaiting fix: ("serial", "binary"), + ("tcp", "socket", BASE_PORT + 1), + ("tcp", "rtu", BASE_PORT + 2), + # awaiting fix: ("tls", "tls", BASE_PORT + 3), + ("udp", "socket", BASE_PORT + 4), + ("udp", "rtu", BASE_PORT + 5), + ("serial", "rtu", BASE_PORT + 6), + # awaiting fix: ("serial", "ascii", BASE_PORT + 7), + # awaiting fix: ("serial", "binary", BASE_PORT + 8), ] - @pytest.mark.xdist_group(name="server_serialize") + @pytest.mark.parametrize("port_offset", [0]) @pytest.mark.parametrize( - ("use_comm", "use_framer"), + ("use_comm", "use_framer", "use_port"), USE_CASES, ) def test_combinations( @@ -52,9 +55,9 @@ def test_combinations( run_sync_client(test_client, modbus_calls=run_a_few_calls) ServerStop() - @pytest.mark.xdist_group(name="server_serialize") + @pytest.mark.parametrize("port_offset", [10]) @pytest.mark.parametrize( - ("use_comm", "use_framer"), + ("use_comm", "use_framer", "use_port"), USE_CASES, ) def test_server_no_client(self, mock_cmdline): @@ -66,9 +69,9 @@ def test_server_no_client(self, mock_cmdline): sleep(1) ServerStop() - @pytest.mark.xdist_group(name="server_serialize") + @pytest.mark.parametrize("port_offset", [20]) @pytest.mark.parametrize( - ("use_comm", "use_framer"), + ("use_comm", "use_framer", "use_port"), USE_CASES, ) def test_server_client_twice(self, mock_cmdline): @@ -84,9 +87,9 @@ def test_server_client_twice(self, mock_cmdline): run_sync_client(test_client, modbus_calls=run_a_few_calls) ServerStop() - @pytest.mark.xdist_group(name="server_serialize") + @pytest.mark.parametrize("port_offset", [30]) @pytest.mark.parametrize( - ("use_comm", "use_framer"), + ("use_comm", "use_framer", "use_port"), USE_CASES, ) def test_client_no_server(self, mock_cmdline): diff --git a/test/sub_examples/test_examples.py b/test/sub_examples/test_examples.py index 2dd981917..836045189 100755 --- a/test/sub_examples/test_examples.py +++ b/test/sub_examples/test_examples.py @@ -24,18 +24,21 @@ from pymodbus.server import ServerAsyncStop +BASE_PORT = 6400 + + class TestExamples: """Test examples.""" USE_CASES = [ - ("tcp", "socket"), - ("tcp", "rtu"), - ("tls", "tls"), - ("udp", "socket"), - ("udp", "rtu"), - ("serial", "rtu"), - # awaiting fix: ("serial", "ascii"), - # awaiting fix: ("serial", "binary"), + ("tcp", "socket", BASE_PORT + 1), + ("tcp", "rtu", BASE_PORT + 2), + ("tls", "tls", BASE_PORT + 3), + ("udp", "socket", BASE_PORT + 4), + ("udp", "rtu", BASE_PORT + 5), + ("serial", "rtu", BASE_PORT + 6), + # awaiting fix: ("serial", "ascii", BASE_PORT + 7), + # awaiting fix: ("serial", "binary", BASE_PORT + 8), ] def test_build_bcd_payload(self): @@ -54,9 +57,8 @@ def test_message_parser(self): parse_messages(["--framer", "socket", "-m", "000100000006010100200001"]) parse_messages(["--framer", "socket", "-m", "00010000000401010101"]) - @pytest.mark.xdist_group(name="server_serialize") @pytest.mark.parametrize( - ("use_comm", "use_framer"), + ("use_comm", "use_framer", "use_port"), USE_CASES, ) async def test_client_calls(self, mock_server): @@ -65,11 +67,10 @@ async def test_client_calls(self, mock_server): test_client = setup_async_client(cmdline=cmdline) await run_async_client(test_client, modbus_calls=run_async_calls) - @pytest.mark.xdist_group(name="server_serialize") @pytest.mark.parametrize( - ("use_comm", "use_framer"), + ("use_comm", "use_framer", "use_port"), [ - ("tcp", "socket"), + ("tcp", "socket", BASE_PORT + 41), ], ) def test_custom_msg(self, use_port, mock_server): @@ -77,11 +78,10 @@ def test_custom_msg(self, use_port, mock_server): _cmdline = mock_server run_custom_client("localhost", use_port) - @pytest.mark.xdist_group(name="server_serialize") @pytest.mark.parametrize( - ("use_comm", "use_framer"), + ("use_comm", "use_framer", "use_port"), [ - ("tcp", "socket"), + ("tcp", "socket", BASE_PORT + 42), ], ) async def test_payload(self, mock_cmdline): @@ -97,7 +97,7 @@ async def test_payload(self, mock_cmdline): task.cancel() await task - @pytest.mark.xdist_group(name="server_serialize") + @pytest.mark.parametrize("use_port", [BASE_PORT + 43]) async def test_datastore_simulator(self, use_port): """Test server simulator.""" cmdargs = ["--port", str(use_port)] @@ -113,7 +113,7 @@ async def test_datastore_simulator(self, use_port): task.cancel() await task - @pytest.mark.xdist_group(name="server_serialize") + @pytest.mark.parametrize("use_port", [BASE_PORT + 44]) async def test_server_callback(self, use_port): """Test server/client with payload.""" cmdargs = ["--port", str(use_port)] @@ -127,7 +127,7 @@ async def test_server_callback(self, use_port): task.cancel() await task - @pytest.mark.xdist_group(name="server_serialize") + @pytest.mark.parametrize("use_port", [BASE_PORT + 45]) async def test_updating_server(self, use_port): """Test server simulator.""" cmdargs = ["--port", str(use_port)] @@ -142,11 +142,10 @@ async def test_updating_server(self, use_port): task.cancel() await task - @pytest.mark.xdist_group(name="server_serialize") @pytest.mark.parametrize( - ("use_comm", "use_framer"), + ("use_comm", "use_framer", "use_port"), [ - ("tcp", "socket"), + ("tcp", "socket", BASE_PORT + 46), ], ) async def test_simple_async_client(self, use_port, mock_server): @@ -154,11 +153,10 @@ async def test_simple_async_client(self, use_port, mock_server): _cmdline = mock_server await run_simple_async_client("127.0.0.1", str(use_port)) - @pytest.mark.xdist_group(name="server_serialize") @pytest.mark.parametrize( - ("use_comm", "use_framer"), + ("use_comm", "use_framer", "use_port"), [ - ("tcp", "socket"), + ("tcp", "socket", BASE_PORT + 47), ], ) async def test_simple_sync_client(self, use_port, mock_server): diff --git a/test/sub_transport/conftest.py b/test/sub_transport/conftest.py index 29bfdbdc7..69a00e710 100644 --- a/test/sub_transport/conftest.py +++ b/test/sub_transport/conftest.py @@ -1,139 +1,132 @@ """Fixtures for transport tests.""" import asyncio +import dataclasses import os -import time from contextlib import suppress -from dataclasses import dataclass -from tempfile import gettempdir from unittest import mock import pytest -import pytest_asyncio -from pymodbus.transport.nullmodem import NullModem -from pymodbus.transport.transport import Transport +from pymodbus.transport.transport import CommParams, CommType, NullModem, Transport -@dataclass -class BaseParams(Transport.CommParamsClass): - """Base parameters for all transport testing.""" +class DummyTransport(asyncio.BaseTransport): + """Use in connection_made calls.""" - comm_name = "test comm" - reconnect_delay = 1000 - reconnect_delay_max = 3500 - timeout_connect = 2000 - host = "test host" - port = 502 - server_hostname = "server test host" - baudrate = 9600 - bytesize = 8 - parity = "e" - stopbits = 2 - cwd = os.path.dirname(__file__) + "/../../examples/certificates/pymodbus." + def transport_close(self): + """Define dummy.""" + def transport_send(self): + """Define dummy.""" -@pytest.fixture(name="params") -def prepare_baseparams(use_port): - """Prepare BaseParams class.""" - BaseParams.port = use_port - return BaseParams + def close(self): + """Define dummy.""" + def get_protocol(self): + """Define dummy.""" -@pytest.fixture(name="commparams") -def prepare_testparams(): - """Prepare CommParamsClass object.""" - return Transport.CommParamsClass( - done=True, - comm_name=BaseParams.comm_name, - reconnect_delay=BaseParams.reconnect_delay / 1000, - reconnect_delay_max=BaseParams.reconnect_delay_max / 1000, - timeout_connect=BaseParams.timeout_connect / 1000, - ) + def is_closing(self): + """Define dummy.""" + def set_protocol(self, _protocol): + """Define dummy.""" -@pytest.fixture(name="transport") -async def prepare_transport(): - """Prepare transport object.""" - transport = Transport( - BaseParams.comm_name, - BaseParams.reconnect_delay, - BaseParams.reconnect_delay_max, - BaseParams.timeout_connect, - mock.Mock(name="cb_connection_made"), - mock.Mock(name="cb_connection_lost"), - mock.Mock(name="cb_handle_data", return_value=0), - ) - with suppress(RuntimeError): - transport.loop = asyncio.get_running_loop() - return transport + def abort(self): + """Define dummy.""" -@pytest.fixture(name="nullmodem") -async def prepare_nullmodem(): - """Prepare nullmodem object.""" - transport = NullModem( - BaseParams.comm_name, - BaseParams.reconnect_delay, - BaseParams.reconnect_delay_max, - BaseParams.timeout_connect, - mock.Mock(name="cb_connection_made"), - mock.Mock(name="cb_connection_lost"), - mock.Mock(name="cb_handle_data", return_value=0), - ) - transport.__class__.nullmodem_client = None - transport.__class__.nullmodem_server = None - with suppress(RuntimeError): - transport.loop = asyncio.get_running_loop() - return transport +@pytest.fixture(name="dummy_transport") +def prepare_dummy_transport(): + """Return transport object""" + return DummyTransport() -@pytest.fixture(name="nullmodem_server") -async def prepare_nullmodem_server(): - """Prepare nullmodem object.""" - transport = NullModem( - BaseParams.comm_name, - BaseParams.reconnect_delay, - BaseParams.reconnect_delay_max, - BaseParams.timeout_connect, - mock.Mock(name="cb_connection_made"), - mock.Mock(name="cb_connection_lost"), - mock.Mock(name="cb_handle_data", return_value=0), +@pytest.fixture(name="cwd_certificate") +def prepare_cwd_certificate(): + """Prepare path to certificate.""" + return os.path.dirname(__file__) + "/../../examples/certificates/pymodbus." + + +@pytest.fixture(name="use_comm_type") +def prepare_dummy_use_comm_type(): + """Return default comm_type""" + return CommType.TCP + + +@pytest.fixture(name="use_host") +def prepare_dummy_use_host(): + """Return default host""" + return "localhost" + + +@pytest.fixture(name="commparams") +def prepare_commparams(use_port, use_host, use_comm_type): + """Prepare CommParamsClass object.""" + return CommParams( + comm_name="test comm", + comm_type=use_comm_type, + reconnect_delay=1, + reconnect_delay_max=3.5, + timeout_connect=2, + host=use_host, + port=use_port, + baudrate=9600, + bytesize=8, + parity="E", + stopbits=2, ) - transport.__class__.nullmodem_client = None - transport.__class__.nullmodem_server = None + + +@pytest.fixture(name="client") +async def prepare_transport(commparams): + """Prepare transport object.""" + transport = Transport(commparams, False) with suppress(RuntimeError): transport.loop = asyncio.get_running_loop() + transport.callback_connected = mock.Mock() + transport.callback_disconnected = mock.Mock() + transport.callback_data = mock.Mock(return_value=0) + if commparams.comm_type == CommType.TLS: + cwd = os.path.dirname(__file__) + "/../../examples/certificates/pymodbus." + transport.comm_params.sslctx = commparams.generate_ssl( + False, certfile=cwd + "crt", keyfile=cwd + "key" + ) + if commparams.comm_type == CommType.SERIAL: + transport.comm_params.host = f"socket://localhost:{transport.comm_params.port}" return transport -@pytest_asyncio.fixture(name="transport_server") -async def prepare_transport_server(): +@pytest.fixture(name="server") +async def prepare_transport_server(commparams): """Prepare transport object.""" - transport = Transport( - BaseParams.comm_name, - BaseParams.reconnect_delay, - BaseParams.reconnect_delay_max, - BaseParams.timeout_connect, - mock.Mock(name="cb_connection_made"), - mock.Mock(name="cb_connection_lost"), - mock.Mock(name="cb_handle_data", return_value=0), - ) + if commparams.comm_type == CommType.SERIAL: + commparams = dataclasses.replace(commparams) + commparams.comm_type = CommType.TCP + transport = Transport(commparams, True) with suppress(RuntimeError): transport.loop = asyncio.get_running_loop() + transport.callback_connected = mock.Mock() + transport.callback_disconnected = mock.Mock() + transport.callback_data = mock.Mock(return_value=0) + if commparams.comm_type == CommType.TLS: + cwd = os.path.dirname(__file__) + "/../../examples/certificates/pymodbus." + transport.comm_params.sslctx = commparams.generate_ssl( + True, certfile=cwd + "crt", keyfile=cwd + "key" + ) + elif commparams.comm_type == CommType.SERIAL: + serial_params = dataclasses.replace(commparams) + serial_params.comm_type = CommType.TCP + transport = Transport(serial_params, True) return transport -@pytest.fixture(name="domain_host") -def get_domain_host(positive): - """Get test host.""" - return "localhost" if positive else "/illegal_host_name" +@pytest.fixture(name="nullmodem") +def prepare_nullmodem(): + """Prepare nullmodem object.""" + return NullModem(False, mock.Mock()) -@pytest.fixture(name="domain_socket") -def get_domain_socket(positive): - """Get test file.""" - return ( - gettempdir() + "/test_unix_" + str(time.time()) - if positive - else "/illegal_file_name" - ) +@pytest.fixture(name="nullmodem_server") +def prepare_nullmodem_server(): + """Prepare nullmodem object.""" + return NullModem(True, mock.Mock()) diff --git a/test/sub_transport/test_basic.py b/test/sub_transport/test_basic.py index 3535b954e..91eca306f 100644 --- a/test/sub_transport/test_basic.py +++ b/test/sub_transport/test_basic.py @@ -3,458 +3,219 @@ from unittest import mock import pytest -from serial import SerialException -from pymodbus.transport.nullmodem import DummyTransport +from pymodbus.transport.transport import NULLMODEM_HOST, CommType, NullModem, Transport -class TestBasicTransport: - """Test transport module, base part.""" - - async def test_init(self, transport, commparams): - """Test init()""" - commparams.done = False - assert transport.comm_params == commparams - assert ( - transport.cb_connection_made._extract_mock_name() # pylint: disable=protected-access - == "cb_connection_made" - ) - assert ( - transport.cb_connection_lost._extract_mock_name() # pylint: disable=protected-access - == "cb_connection_lost" - ) - assert ( - transport.cb_handle_data._extract_mock_name() # pylint: disable=protected-access - == "cb_handle_data" - ) - assert not transport.reconnect_delay_current - assert not transport.reconnect_task +COMM_TYPES = [ + CommType.TCP, + CommType.TLS, + CommType.UDP, + CommType.SERIAL, +] - async def test_property_done(self, transport): - """Test done property""" - transport.comm_params.check_done() - with pytest.raises(RuntimeError): - transport.comm_params.check_done() - async def test_with_magic(self, transport): - """Test magic.""" - transport.close = mock.MagicMock() - async with transport: - pass - transport.close.assert_called_once() +class TestBasicTransport: + """Test transport module.""" - async def test_str_magic(self, params, transport): - """Test magic.""" - assert str(transport) == f"Transport({params.comm_name})" + @pytest.mark.parametrize("use_comm_type", COMM_TYPES) + async def test_init(self, client, server, commparams): + """Test init()""" + if commparams.comm_type == CommType.SERIAL: + client.comm_params.host = commparams.host + server.comm_params.comm_type = commparams.comm_type + client.comm_params.sslctx = None + assert client.comm_params == commparams + assert client.unique_id == str(id(client)) + assert not client.is_server + server.comm_params.sslctx = None + assert server.comm_params == commparams + assert server.unique_id == str(id(server)) + assert server.is_server + + commparams.host = NULLMODEM_HOST + Transport(commparams, False) + + async def test_connect(self, client, dummy_transport): + """Test properties.""" + client.loop = None + client.call_create = mock.AsyncMock(return_value=(dummy_transport, None)) + assert await client.transport_connect() + assert client.loop + client.call_create.side_effect = asyncio.TimeoutError("test") + assert not await client.transport_connect() + + async def test_listen(self, server, dummy_transport): + """Test listen_tcp().""" + server.call_create = mock.AsyncMock(return_value=(dummy_transport, None)) + assert await server.transport_listen() + server.call_create.side_effect = OSError("testing") + assert not await server.transport_listen() - async def test_connection_made(self, transport, commparams): + async def test_connection_made(self, client, commparams, dummy_transport): """Test connection_made().""" - transport.loop = None - transport.connection_made(DummyTransport()) - assert transport.transport - assert not transport.recv_buffer - assert not transport.reconnect_task - assert transport.reconnect_delay_current == commparams.reconnect_delay - transport.cb_connection_made.assert_called_once() - transport.cb_connection_lost.assert_not_called() - transport.cb_handle_data.assert_not_called() - transport.close() - - async def test_connection_lost(self, transport): + client.connection_made(dummy_transport) + assert client.transport + assert not client.recv_buffer + assert not client.reconnect_task + assert client.reconnect_delay_current == commparams.reconnect_delay + client.callback_connected.assert_called_once() + + async def test_connection_lost(self, client, dummy_transport): """Test connection_lost().""" - transport.connection_lost(RuntimeError("not implemented")) - assert not transport.transport - assert not transport.recv_buffer - assert not transport.reconnect_task - assert not transport.reconnect_delay_current - transport.cb_connection_made.assert_not_called() - transport.cb_handle_data.assert_not_called() - transport.cb_connection_lost.assert_called_once() - - transport.transport = mock.Mock() - transport.connection_lost(RuntimeError("not implemented")) - assert not transport.transport - assert transport.reconnect_task - transport.close() - assert not transport.reconnect_task - - async def test_close(self, transport): - """Test close().""" - socket = DummyTransport() - socket.abort = mock.Mock() - socket.close = mock.Mock() - transport.connection_made(socket) - transport.cb_connection_made.reset_mock() - transport.cb_connection_lost.reset_mock() - transport.cb_handle_data.reset_mock() - transport.recv_buffer = b"abc" - transport.reconnect_task = mock.MagicMock() - transport.close() - socket.abort.assert_called_once() - socket.close.assert_called_once() - transport.cb_connection_made.assert_not_called() - transport.cb_connection_lost.assert_not_called() - transport.cb_handle_data.assert_not_called() - assert not transport.recv_buffer - assert not transport.reconnect_task - - async def test_reset_delay(self, transport, commparams): - """Test reset_delay().""" - transport.reconnect_delay_current += 5.17 - transport.reset_delay() - assert transport.reconnect_delay_current == commparams.reconnect_delay - - async def test_datagram(self, transport): + client.connection_lost(RuntimeError("not implemented")) + client.connection_made(dummy_transport) + client.connection_lost(RuntimeError("not implemented")) + assert not client.transport + assert not client.recv_buffer + assert client.reconnect_task + client.callback_disconnected.assert_called_once() + client.transport_close() + assert not client.reconnect_task + assert not client.reconnect_delay_current + + async def test_data_received(self, client): + """Test data_received.""" + client.callback_data = mock.MagicMock(return_value=2) + client.data_received(b"123456") + client.callback_data.assert_called_once() + assert client.recv_buffer == b"3456" + client.data_received(b"789") + assert client.recv_buffer == b"56789" + + async def test_datagram(self, client): """Test datagram_received().""" - transport.data_received = mock.MagicMock() - transport.datagram_received(b"abc", "127.0.0.1") - transport.data_received.assert_called_once() + client.callback_data = mock.MagicMock() + client.datagram_received(b"abc", "127.0.0.1") + client.callback_data.assert_called_once() - async def test_data(self, transport): - """Test data_received.""" - transport.cb_handle_data = mock.MagicMock(return_value=2) - transport.data_received(b"123456") - transport.cb_handle_data.assert_called_once() - assert transport.recv_buffer == b"3456" - transport.data_received(b"789") - assert transport.recv_buffer == b"56789" - - async def test_eof_received(self, transport): + async def test_eof_received(self, client): """Test eof_received.""" - transport.eof_received() + client.eof_received() - async def test_error_received(self, transport): + async def test_error_received(self, client): """Test error_received.""" with pytest.raises(RuntimeError): - transport.error_received(Exception("test call")) - - async def test_send(self, transport, params): - """Test send().""" - transport.transport = mock.AsyncMock() - await transport.send(b"abc") - - transport.setup_udp(False, params.host, params.port) - await transport.send(b"abc") - transport.close() - - async def test_handle_listen(self, transport): - """Test handle_listen().""" - assert transport == transport.handle_listen() - - async def test_no_loop(self, transport): - """Test properties.""" - transport.loop = None - transport.call_connect_listen = mock.AsyncMock(return_value=(117, 118)) - await transport.transport_connect() - assert transport.loop - - async def test_reconnect_connect(self, transport): - """Test handle_listen().""" - transport.comm_params.reconnect_delay = 0.01 - transport.transport_connect = mock.AsyncMock(side_effect=[False, True]) - await transport.reconnect_connect() - assert ( - transport.reconnect_delay_current - == transport.comm_params.reconnect_delay * 2 - ) - assert not transport.reconnect_task - transport.transport_connect = mock.AsyncMock( - side_effect=asyncio.CancelledError("stop loop") - ) - await transport.reconnect_connect() - assert ( - transport.reconnect_delay_current == transport.comm_params.reconnect_delay - ) - assert not transport.reconnect_task - - -@pytest.mark.skipif(pytest.IS_WINDOWS, reason="not implemented") -class TestBasicUnixTransport: - """Test transport module, unix part.""" - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("setup_server", [True, False]) - def test_properties(self, params, setup_server, transport, commparams): - """Test properties.""" - transport.setup_unix(setup_server, params.host) - commparams.host = params.host - assert transport.comm_params == commparams - assert transport.call_connect_listen - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("setup_server", [True, False]) - def test_properties_windows(self, params, setup_server, transport): - """Test properties.""" - with mock.patch( - "pymodbus.transport.transport.sys.platform", return_value="windows" - ), pytest.raises(RuntimeError): - transport.setup_unix(setup_server, params.host) - - @pytest.mark.xdist_group(name="server_serialize") - async def test_connect(self, params, transport): - """Test connect_unix().""" - transport.setup_unix(False, params.host) - mocker = mock.AsyncMock() - transport.loop.create_unix_connection = mocker - mocker.side_effect = FileNotFoundError("testing") - assert not await transport.transport_connect() - mocker.side_effect = None - - mocker.return_value = (mock.Mock(), mock.Mock()) - assert await transport.transport_connect() - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_listen(self, params, transport): - """Test listen_unix().""" - transport.setup_unix(True, params.host) - mocker = mock.AsyncMock() - transport.loop.create_unix_server = mocker - mocker.side_effect = OSError("testing") - assert await transport.transport_listen() is None - mocker.side_effect = None - - mocker.return_value = mock.Mock() - assert mocker.return_value == await transport.transport_listen() - transport.close() - - -class TestBasicTcpTransport: - """Test transport module, tcp part.""" - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("setup_server", [True, False]) - def test_properties(self, params, setup_server, transport, commparams): - """Test properties.""" - transport.setup_tcp(setup_server, params.host, params.port) - commparams.host = params.host - commparams.port = params.port - assert transport.comm_params == commparams - assert transport.call_connect_listen - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_connect(self, params, transport): - """Test connect_tcp().""" - transport.setup_tcp(False, params.host, params.port) - mocker = mock.AsyncMock() - transport.loop.create_connection = mocker - mocker.side_effect = asyncio.TimeoutError("testing") - assert not await transport.transport_connect() - mocker.side_effect = None - - mocker.return_value = (mock.Mock(), mock.Mock()) - assert await transport.transport_connect() - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_listen(self, params, transport): - """Test listen_tcp().""" - transport.setup_tcp(True, params.host, params.port) - mocker = mock.AsyncMock() - transport.loop.create_server = mocker - mocker.side_effect = OSError("testing") - assert await transport.transport_listen() is None - mocker.side_effect = None - - mocker.return_value = mock.Mock() - assert mocker.return_value == await transport.transport_listen() - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_is_active(self, params, transport): - """Test properties.""" - transport.setup_tcp(False, params.host, params.port) - assert not transport.is_active() - transport.connection_made(mock.AsyncMock()) - assert transport.is_active() - transport.close() - + client.error_received(Exception("test call")) + + async def test_callbacks(self, commparams): + """Test callbacks.""" + client = Transport(commparams, False) + client.callback_connected() + client.callback_disconnected(Exception("test")) + client.callback_data(b"abcd") + + async def test_transport_send(self, client): + """Test transport_send().""" + client.transport = mock.AsyncMock() + client.transport_send(b"abc") + + client.comm_params.comm_type = CommType.UDP + client.transport_send(b"abc") + client.transport_send(b"abc", addr=("localhost", 502)) + + async def test_transport_close(self, server, dummy_transport): + """Test transport_close().""" + dummy_transport.abort = mock.Mock() + dummy_transport.close = mock.Mock() + server.connection_made(dummy_transport) + server.recv_buffer = b"abc" + server.reconnect_task = mock.MagicMock() + server.listener = mock.MagicMock() + server.transport_close() + dummy_transport.abort.assert_called_once() + dummy_transport.close.assert_called_once() + assert not server.recv_buffer + assert not server.reconnect_task + server.listener = None + server.active_connections = {"a": dummy_transport} + server.transport_close() + assert not server.active_connections + + async def test_reset_delay(self, client, commparams): + """Test reset_delay().""" + client.reconnect_delay_current += 5.17 + client.reset_delay() + assert client.reconnect_delay_current == commparams.reconnect_delay + + async def test_is_active(self, client): + """Test is_active().""" + assert not client.is_active() + client.connection_made(mock.AsyncMock()) + assert client.is_active() + + @pytest.mark.parametrize("use_host", [NULLMODEM_HOST]) + async def test_create_nullmodem(self, client, server): + """Test create_nullmodem.""" + await server.transport_listen() + await client.transport_listen() + + async def test_handle_new_connection(self, client, server): + """Test handle_new_connection().""" + server.handle_new_connection() + client.handle_new_connection() + + async def test_do_reconnect(self, client): + """Test do_reconnect().""" + client.comm_params.reconnect_delay = 0.01 + client.transport_connect = mock.AsyncMock(side_effect=[False, True]) + await client.do_reconnect() + assert client.reconnect_delay_current == client.comm_params.reconnect_delay * 2 + assert not client.reconnect_task + client.transport_connect.side_effect = asyncio.CancelledError("stop loop") + await client.do_reconnect() + assert client.reconnect_delay_current == client.comm_params.reconnect_delay + assert not client.reconnect_task + + async def test_with_magic(self, client): + """Test magic.""" + client.transport_close = mock.MagicMock() + async with client: + pass + client.transport_close.assert_called_once() -class TestBasicTlsTransport: - """Test transport module, tls part.""" + async def test_str_magic(self, commparams, client): + """Test magic.""" + assert str(client) == f"Transport({commparams.comm_name})" - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("setup_server", [True, False]) - @pytest.mark.parametrize("sslctx", [None, "test ctx"]) - def test_properties(self, setup_server, sslctx, params, transport, commparams): - """Test properties.""" + def test_generate_ssl(self, commparams): + """Test ssl generattion""" with mock.patch("pymodbus.transport.transport.ssl.SSLContext"): - transport.setup_tls( - setup_server, - params.host, - params.port, - sslctx, - "certfile dummy", - None, - None, - params.server_hostname, - ) - commparams.host = params.host - commparams.port = params.port - commparams.server_hostname = params.server_hostname - commparams.ssl = sslctx if sslctx else transport.comm_params.ssl - assert transport.comm_params == commparams - assert transport.call_connect_listen - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_connect(self, params, transport): - """Test connect_tcls().""" - transport.setup_tls( - False, - params.host, - params.port, - "no ssl", - None, - None, - None, - params.server_hostname, - ) - mocker = mock.AsyncMock() - transport.loop.create_connection = mocker - mocker.side_effect = asyncio.TimeoutError("testing") - assert not await transport.transport_connect() - mocker.side_effect = None - - mocker.return_value = (mock.Mock(), mock.Mock()) - assert await transport.transport_connect() - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_listen(self, params, transport): - """Test listen_tls().""" - transport.setup_tls( - True, - params.host, - params.port, - "no ssl", - None, - None, - None, - params.server_hostname, + sslctx = commparams.generate_ssl(True, "cert_file", "key_file") + assert sslctx + test_value = "test igen" + assert test_value == commparams.generate_ssl( + True, "cert_file", "key_file", sslctx=test_value ) - mocker = mock.AsyncMock() - transport.loop.create_server = mocker - mocker.side_effect = OSError("testing") - assert await transport.transport_listen() is None - mocker.side_effect = None - - mocker.return_value = mock.Mock() - assert mocker.return_value == await transport.transport_listen() - transport.close() -class TestBasicUdpTransport: - """Test transport module, udp part.""" +class TestBasicNullModem: + """Test transport null modem module.""" - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("setup_server", [True, False]) - def test_properties(self, params, setup_server, transport, commparams): - """Test properties.""" - transport.setup_udp(setup_server, params.host, params.port) - commparams.host = params.host - commparams.port = params.port - assert transport.comm_params == commparams - assert transport.call_connect_listen - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_connect(self, params, transport): - """Test connect_udp().""" - transport.setup_udp(False, params.host, params.port) - mocker = mock.AsyncMock() - transport.loop.create_datagram_endpoint = mocker - mocker.side_effect = asyncio.TimeoutError("testing") - assert not await transport.transport_connect() - mocker.side_effect = None - - mocker.return_value = (mock.Mock(), mock.Mock()) - assert await transport.transport_connect() - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_listen(self, params, transport): - """Test listen_udp().""" - transport.setup_udp(True, params.host, params.port) - mocker = mock.AsyncMock() - transport.loop.create_datagram_endpoint = mocker - mocker.side_effect = OSError("testing") - assert await transport.transport_listen() is None - mocker.side_effect = None - - mocker.return_value = (mock.Mock(), mock.Mock()) - assert await transport.transport_listen() == mocker.return_value[0] - transport.close() - - -class TestBasicSerialTransport: - """Test transport module, serial part.""" - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("setup_server", [True, False]) - def test_properties(self, params, setup_server, transport, commparams): - """Test properties.""" - transport.setup_serial( - setup_server, - params.host, - params.baudrate, - params.bytesize, - params.parity, - params.stopbits, - ) - commparams.host = params.host - commparams.baudrate = params.baudrate - commparams.bytesize = params.bytesize - commparams.parity = params.parity - commparams.stopbits = params.stopbits - assert transport.comm_params == commparams - assert transport.call_connect_listen - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_connect(self, params, transport): - """Test connect_serial().""" - transport.setup_serial( - False, - params.host, - params.baudrate, - params.bytesize, - params.parity, - params.stopbits, - ) - mocker = mock.AsyncMock() - with mock.patch( - "pymodbus.transport.transport.create_serial_connection", new=mocker - ): - mocker.side_effect = asyncio.TimeoutError("testing") - assert not await transport.transport_connect() - mocker.side_effect = None - - mocker.return_value = (mock.Mock(), mock.Mock()) - assert await transport.transport_connect() - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - async def test_listen(self, params, transport): - """Test listen_serial().""" - transport.setup_serial( - True, - params.host, - params.baudrate, - params.bytesize, - params.parity, - params.stopbits, - ) - mocker = mock.AsyncMock() - with mock.patch( - "pymodbus.transport.transport.create_serial_connection", new=mocker - ): - mocker.side_effect = SerialException("testing") - assert await transport.transport_listen() is None - mocker.side_effect = None - - mocker.return_value = mock.Mock() - assert await transport.transport_listen() == mocker.return_value - transport.close() + def test_init(self): + """Test null modem init""" + NullModem.server = None + with pytest.raises(RuntimeError): + NullModem(False, mock.Mock()) + NullModem(True, mock.Mock()) + NullModem(False, mock.Mock()) + + def test_external_methods(self): + """Test external methods.""" + modem = NullModem(True, mock.Mock()) + modem.close() + modem.sendto(b"abcd") + modem.write(b"abcd") + + def test_abstract_methods(self): + """Test asyncio abstract methods.""" + modem = NullModem(True, mock.Mock()) + modem.abort() + modem.can_write_eof() + modem.get_write_buffer_size() + modem.get_write_buffer_limits() + modem.set_write_buffer_limits(1024, 1) + modem.write_eof() + modem.get_protocol() + modem.set_protocol(None) + modem.is_closing() diff --git a/test/sub_transport/test_comm.py b/test/sub_transport/test_comm.py index bfbfc5f8d..3d67cacc4 100644 --- a/test/sub_transport/test_comm.py +++ b/test/sub_transport/test_comm.py @@ -1,241 +1,133 @@ """Test transport.""" +import asyncio import time import pytest +from pymodbus.transport.transport import CommType -@pytest.mark.skipif(pytest.IS_WINDOWS, reason="not implemented.") -class TestCommUnixTransport: - """Test for the transport module.""" - - @pytest.mark.parametrize("positive", [True, False]) - async def test_connect(self, transport, domain_socket): - """Test connect_unix().""" - transport.setup_unix(False, domain_socket) - start = time.time() - assert not await transport.transport_connect() - delta = time.time() - start - assert delta < transport.comm_params.timeout_connect * 1.2 - transport.close() - @pytest.mark.parametrize("positive", [True, False]) - async def test_listen(self, transport_server, positive, domain_socket): - """Test listen_unix().""" - transport_server.setup_unix(True, domain_socket) - server = await transport_server.transport_listen() - assert positive == bool(server) - assert positive == bool(transport_server.transport) - if server: - server.close() - transport_server.close() +BASE_PORT = 6100 +FACTOR = 1.2 if not pytest.IS_WINDOWS else 2.2 - @pytest.mark.parametrize("positive", [True]) - async def test_connected(self, transport, transport_server, domain_socket): - """Test listen/connect unix().""" - transport_server.setup_unix(True, domain_socket) - await transport_server.transport_listen() - transport.setup_unix(False, domain_socket) - assert await transport.transport_connect() - transport.close() - transport_server.close() - - -class TestCommTcpTransport: +class TestCommTransport: """Test for the transport module.""" - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("positive", [True, False]) - async def test_connect(self, transport, use_port, domain_host): - """Test connect_tcp().""" - transport.setup_tcp(False, domain_host, use_port) + @pytest.mark.parametrize( + ("use_comm_type", "use_port"), + [ + (CommType.TCP, BASE_PORT + 1), + (CommType.TLS, BASE_PORT + 2), + # (CommType.UDP, BASE_PORT + 3), udp is connectionless. + (CommType.SERIAL, BASE_PORT + 4), + ], + ) + async def test_connect(self, client): + """Test connect().""" start = time.time() - assert not await transport.transport_connect() + assert not await client.transport_connect() delta = time.time() - start - assert delta < transport.comm_params.timeout_connect * 1.2 - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("positive", [True, False]) - async def test_listen(self, transport_server, use_port, positive, domain_host): - """Test listen_tcp().""" - transport_server.setup_tcp(True, domain_host, use_port) - server = await transport_server.transport_listen() - assert positive == bool(server) - assert positive == bool(transport_server.transport) - transport_server.close() - if server: - server.close() - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("positive", [True]) - async def test_connected(self, transport, transport_server, use_port, domain_host): - """Test listen/connect tcp().""" - transport_server.setup_tcp(True, domain_host, use_port) - server = await transport_server.transport_listen() - assert server - transport.setup_tcp(False, domain_host, use_port) - assert await transport.transport_connect() - transport.close() - transport_server.close() - server.close() - - -class TestCommTlsTransport: - """Test for the transport module.""" - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("positive", [True, False]) - async def test_connect(self, transport, params, domain_host): - """Test connect_tls().""" - transport.setup_tls( - False, - domain_host, - params.port, - None, - params.cwd + "crt", - params.cwd + "key", - None, - "localhost", - ) - start = time.time() - assert not await transport.transport_connect() - delta = time.time() - start - assert delta < transport.comm_params.timeout_connect * 1.2 - transport.close() - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("positive", [True, False]) - async def test_listen(self, transport_server, params, positive, domain_host): - """Test listen_tls().""" - transport_server.setup_tls( - True, - domain_host, - params.port, - None, - params.cwd + "crt", - params.cwd + "key", - None, - "localhost", - ) - server = await transport_server.transport_listen() - assert positive == bool(server) - assert positive == bool(transport_server.transport) - transport_server.close() - if server: - server.close() - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("positive", [True]) - async def test_connected(self, transport, transport_server, params, domain_host): - """Test listen/connect tls().""" - transport_server.setup_tls( - True, - domain_host, - params.port, - None, - params.cwd + "crt", - params.cwd + "key", - None, - "localhost", - ) - server = await transport_server.transport_listen() - assert server - - transport.setup_tcp(False, domain_host, params.port) - assert await transport.transport_connect() - transport.close() - transport_server.close() - server.close() - - -class TestCommUdpTransport: - """Test for the transport module.""" - - async def test_connect(self): - """Test connect_udp().""" - # always true, since udp is connectionless. - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("positive", [True, False]) - async def test_listen(self, transport_server, use_port, positive, domain_host): - """Test listen_udp().""" - transport_server.setup_udp(True, domain_host, use_port) - server = await transport_server.transport_listen() - assert positive == bool(server) - assert positive == bool(transport_server.transport) - transport_server.close() - if server: - server.close() - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("positive", [True]) - async def test_connected(self, transport, transport_server, use_port, domain_host): - """Test listen/connect udp().""" - transport_server.setup_udp(True, domain_host, use_port) - server = await transport_server.transport_listen() - assert server - transport.setup_udp(False, domain_host, use_port) - assert await transport.transport_connect() - transport.close() - transport_server.close() - server.close() - - -class TestCommSerialTransport: - """Test for the transport module.""" - - @pytest.mark.xdist_group(name="server_serialize") - @pytest.mark.parametrize("positive", [True, False]) - async def test_connect(self, transport, use_port, positive): - """Test connect_serial().""" - domain_port = f"unix:/localhost:{use_port}" if positive else "/illegal_port" - transport.setup_serial( - False, - domain_port, - 9600, - 8, - "E", - 2, - ) + assert delta < client.comm_params.timeout_connect * FACTOR + client.transport_close() + + @pytest.mark.parametrize( + ("use_comm_type", "use_port"), + [ + (CommType.TCP, BASE_PORT + 5), + (CommType.TLS, BASE_PORT + 6), + # (CommType.UDP, BASE_PORT + 7), udp is connectionless. + (CommType.SERIAL, BASE_PORT + 8), + ], + ) + async def test_connect_not_ok(self, client): + """Test connect().""" + client.comm_params.host = "/illegal_host" start = time.time() - assert not await transport.transport_connect() + assert not await client.transport_connect() delta = time.time() - start - assert delta < transport.comm_params.timeout_connect * 1.2 - transport.close() - - async def test_listen(self, transport_server): - """Test listen_serial().""" - transport_server.setup_serial( - True, - "/illegal_port", - 9600, - 8, - "E", - 2, - ) - server = await transport_server.transport_listen() - assert not server - assert not transport_server.transport - transport_server.close() - - # there are no positive test, since there are no standard tty port - - @pytest.mark.xdist_group(name="server_serialize") - async def test_connected(self, transport, transport_server, use_port): - """Test listen/connect serial().""" - transport_server.setup_tcp(True, "localhost", use_port) - server = await transport_server.transport_listen() - assert server - transport.setup_serial( - False, - f"socket://localhost:{use_port}", - 9600, - 8, - "E", - 2, - ) - assert await transport.transport_connect() - transport.close() - transport_server.close() - server.close() + assert delta < client.comm_params.timeout_connect * FACTOR + client.transport_close() + + @pytest.mark.parametrize( + ("use_comm_type", "use_port"), + [ + (CommType.TCP, BASE_PORT + 9), + (CommType.TLS, BASE_PORT + 10), + (CommType.UDP, BASE_PORT + 11), + # (CommType.SERIAL, BASE_PORT + 12), there are no standard tty port + ], + ) + async def test_listen(self, server): + """Test listen().""" + assert await server.transport_listen() + assert server.transport + server.transport_close() + + @pytest.mark.parametrize( + ("use_comm_type", "use_port"), + [ + (CommType.TCP, BASE_PORT + 13), + (CommType.TLS, BASE_PORT + 14), + (CommType.UDP, BASE_PORT + 15), + (CommType.SERIAL, BASE_PORT + 16), + ], + ) + async def test_listen_not_ok(self, server): + """Test listen().""" + server.comm_params.host = "/illegal_host" + assert not await server.transport_listen() + assert not server.transport + server.transport_close() + + @pytest.mark.parametrize( + ("use_comm_type", "use_port"), + [ + (CommType.TCP, BASE_PORT + 13), + (CommType.TLS, BASE_PORT + 14), + (CommType.UDP, BASE_PORT + 15), + (CommType.SERIAL, BASE_PORT + 16), + ], + ) + async def test_connected(self, client, server, use_comm_type): + """Test connection and data exchange.""" + assert await server.transport_listen() + assert await client.transport_connect() + await asyncio.sleep(0.5) + assert len(server.active_connections) == 1 + server_connected = list(server.active_connections.values())[0] + test_data = b"abcd" + client.transport_send(test_data) + await asyncio.sleep(0.5) + assert server_connected.recv_buffer == test_data + assert not client.recv_buffer + server_connected.recv_buffer = b"" + if use_comm_type == CommType.UDP: + sock = client.transport.get_extra_info("socket") + addr = sock.getsockname() + server_connected.transport_send(test_data, addr=addr) + else: + server_connected.transport_send(test_data) + await asyncio.sleep(2) + assert client.recv_buffer == test_data + assert not server_connected.recv_buffer + client.transport_close() + server.transport_close() + + +class TestCommNullModem: # pylint: disable=too-few-public-methods + """Test null modem module.""" + + def test_class_variables(self, nullmodem_server, nullmodem): + """Test connection_made().""" + assert nullmodem.client + assert nullmodem.server + assert nullmodem_server.client + assert nullmodem_server.server + nullmodem.__class__.client = self + nullmodem.is_server = False + nullmodem_server.__class__.server = self + nullmodem_server.is_server = True + + assert nullmodem.client == nullmodem_server.client + assert nullmodem.server == nullmodem_server.server diff --git a/test/sub_transport/test_data.py b/test/sub_transport/test_data.py deleted file mode 100644 index 6d9535e4d..000000000 --- a/test/sub_transport/test_data.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Test transport.""" -import asyncio - -import pytest - - -class TestDataTransport: # pylint: disable=too-few-public-methods - """Test for the transport module.""" - - @pytest.mark.xdist_group(name="server_serialize") - async def test_client_send(self, transport, transport_server, use_port): - """Test send().""" - transport_server.setup_tcp(True, "localhost", use_port) - server = await transport_server.transport_listen() - assert transport_server.transport - - transport.setup_tcp(False, "localhost", use_port) - assert await transport.transport_connect() - await transport.send(b"ABC") - await asyncio.sleep(2) - assert transport_server.recv_buffer == b"ABC" - await transport_server.send(b"DEF") - await asyncio.sleep(2) - assert transport.recv_buffer == b"DEF" - transport.close() - transport_server.close() - server.close() diff --git a/test/sub_transport/test_nullmodem.py b/test/sub_transport/test_nullmodem.py deleted file mode 100644 index 3e61765e7..000000000 --- a/test/sub_transport/test_nullmodem.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Test transport.""" - -from pymodbus.transport.nullmodem import DummyTransport - - -class TestNullModemTransport: - """Test null modem module.""" - - async def test_str_magic(self, nullmodem, params): - """Test magic.""" - str(nullmodem) - assert str(nullmodem) == f"NullModem({params.comm_name})" - - def test_DummyTransport(self): - """Test DummyTransport class.""" - socket = DummyTransport() - socket.close() - socket.get_protocol() - socket.is_closing() - socket.set_protocol(None) - socket.abort() - - def test_class_variables(self, nullmodem, nullmodem_server): - """Test connection_made().""" - assert not nullmodem.nullmodem_client - assert not nullmodem.nullmodem_server - assert not nullmodem_server.nullmodem_client - assert not nullmodem_server.nullmodem_server - nullmodem.__class__.nullmodem_client = self - nullmodem.is_server = False - nullmodem_server.__class__.nullmodem_server = self - nullmodem_server.is_server = True - - assert nullmodem.nullmodem_client == nullmodem_server.nullmodem_client - assert nullmodem.nullmodem_server == nullmodem_server.nullmodem_server - - async def test_transport_connect(self, nullmodem): - """Test connection_made().""" - nullmodem.loop = None - assert not await nullmodem.transport_connect() - assert not nullmodem.nullmodem_server - assert not nullmodem.nullmodem_client - assert nullmodem.loop - nullmodem.cb_connection_made.assert_not_called() - nullmodem.cb_connection_lost.assert_not_called() - nullmodem.cb_handle_data.assert_not_called() - - async def test_transport_listen(self, nullmodem_server): - """Test connection_made().""" - nullmodem_server.loop = None - assert await nullmodem_server.transport_listen() - assert nullmodem_server.is_server - assert nullmodem_server.nullmodem_server - assert not nullmodem_server.nullmodem_client - assert nullmodem_server.loop - nullmodem_server.cb_connection_made.assert_not_called() - nullmodem_server.cb_connection_lost.assert_not_called() - nullmodem_server.cb_handle_data.assert_not_called() - - async def test_connected(self, nullmodem, nullmodem_server): - """Test connection is correct.""" - assert await nullmodem_server.transport_listen() - assert await nullmodem.transport_connect() - assert nullmodem.nullmodem_client - assert nullmodem.nullmodem_server - assert nullmodem.loop - assert not nullmodem.is_server - assert nullmodem_server.is_server - nullmodem.cb_connection_made.assert_called_once() - nullmodem.cb_connection_lost.assert_not_called() - nullmodem.cb_handle_data.assert_not_called() - nullmodem_server.cb_connection_made.assert_called_once() - nullmodem_server.cb_connection_lost.assert_not_called() - nullmodem_server.cb_handle_data.assert_not_called() - - async def test_client_close(self, nullmodem, nullmodem_server): - """Test close().""" - assert await nullmodem_server.transport_listen() - assert await nullmodem.transport_connect() - nullmodem.close() - assert not nullmodem.nullmodem_client - assert not nullmodem.nullmodem_server - nullmodem.cb_connection_made.assert_called_once() - nullmodem.cb_connection_lost.assert_called_once() - nullmodem.cb_handle_data.assert_not_called() - nullmodem_server.cb_connection_made.assert_called_once() - nullmodem_server.cb_connection_lost.assert_called_once() - nullmodem_server.cb_handle_data.assert_not_called() - - async def test_server_close(self, nullmodem, nullmodem_server): - """Test close().""" - assert await nullmodem_server.transport_listen() - assert await nullmodem.transport_connect() - nullmodem_server.close() - assert not nullmodem.nullmodem_client - assert not nullmodem.nullmodem_server - nullmodem.cb_connection_made.assert_called_once() - nullmodem.cb_connection_lost.assert_called_once() - nullmodem.cb_handle_data.assert_not_called() - nullmodem_server.cb_connection_made.assert_called_once() - nullmodem_server.cb_connection_lost.assert_called_once() - nullmodem_server.cb_handle_data.assert_not_called() - - async def test_data(self, nullmodem, nullmodem_server): - """Test data exchange.""" - data = b"abcd" - assert await nullmodem_server.transport_listen() - assert await nullmodem.transport_connect() - assert await nullmodem.send(data) - assert nullmodem_server.recv_buffer == data - assert not nullmodem.recv_buffer - nullmodem.cb_handle_data.assert_not_called() - nullmodem_server.cb_handle_data.assert_called_once() - assert await nullmodem_server.send(data) - assert nullmodem_server.recv_buffer == data - assert nullmodem.recv_buffer == data - nullmodem.cb_handle_data.assert_called_once() - nullmodem_server.cb_handle_data.assert_called_once() diff --git a/test/sub_transport/test_reconnect.py b/test/sub_transport/test_reconnect.py index defb815c9..b1dd5752d 100644 --- a/test/sub_transport/test_reconnect.py +++ b/test/sub_transport/test_reconnect.py @@ -2,68 +2,60 @@ import asyncio from unittest import mock -import pytest - class TestReconnectTransport: """Test transport module, base part.""" - @pytest.mark.xdist_group(name="server_serialize") - async def test_no_reconnect_call(self, transport, use_port, commparams): + async def test_no_reconnect_call(self, client): """Test connection_lost().""" - transport.setup_tcp(False, "localhost", use_port) - mocker = mock.AsyncMock(return_value=(None, None)) - transport.loop.create_connection = mocker - transport.connection_made(mock.Mock()) - assert not mocker.call_count - assert transport.reconnect_delay_current == commparams.reconnect_delay - transport.connection_lost(RuntimeError("Connection lost")) - assert not mocker.call_count - assert transport.reconnect_delay_current == commparams.reconnect_delay - transport.close() + client.loop.create_connection = mock.AsyncMock(return_value=(None, None)) + await client.transport_connect() + client.connection_lost(RuntimeError("Connection lost")) + assert not client.reconnect_task + assert client.loop.create_connection.call_count + assert not client.reconnect_delay_current + client.transport_close() - @pytest.mark.xdist_group(name="server_serialize") - async def test_reconnect_call(self, transport, use_port, commparams): + async def test_reconnect_call(self, client, commparams): """Test connection_lost().""" - transport.setup_tcp(False, "localhost", use_port) - mocker = mock.AsyncMock(return_value=(None, None)) - transport.loop.create_connection = mocker - transport.connection_made(mock.Mock()) - transport.connection_lost(RuntimeError("Connection lost")) - await asyncio.sleep(transport.reconnect_delay_current * 1.8) - assert mocker.call_count == 1 - assert transport.reconnect_delay_current == commparams.reconnect_delay * 2 - transport.close() + client.loop.create_connection = mock.AsyncMock(return_value=(None, None)) + await client.transport_connect() + client.connection_made(mock.Mock()) + client.connection_lost(RuntimeError("Connection lost")) + assert client.reconnect_task + await asyncio.sleep(client.reconnect_delay_current * 1.8) + assert client.reconnect_task + assert client.loop.create_connection.call_count == 2 + assert client.reconnect_delay_current == commparams.reconnect_delay * 2 + client.transport_close() - @pytest.mark.xdist_group(name="server_serialize") - async def test_multi_reconnect_call(self, transport, use_port, commparams): + async def test_multi_reconnect_call(self, client, commparams): """Test connection_lost().""" - transport.setup_tcp(False, "localhost", use_port) - mocker = mock.AsyncMock(return_value=(None, None)) - transport.loop.create_connection = mocker - transport.connection_made(mock.Mock()) - transport.connection_lost(RuntimeError("Connection lost")) - await asyncio.sleep(transport.reconnect_delay_current * 1.8) - assert mocker.call_count == 1 - assert transport.reconnect_delay_current == commparams.reconnect_delay * 2 - await asyncio.sleep(transport.reconnect_delay_current * 1.8) - assert mocker.call_count == 2 - assert transport.reconnect_delay_current == commparams.reconnect_delay_max - await asyncio.sleep(transport.reconnect_delay_current * 1.8) - assert mocker.call_count >= 3 - assert transport.reconnect_delay_current == commparams.reconnect_delay_max - transport.close() + client.loop.create_connection = mock.AsyncMock(return_value=(None, None)) + await client.transport_connect() + client.connection_made(mock.Mock()) + client.connection_lost(RuntimeError("Connection lost")) + await asyncio.sleep(client.reconnect_delay_current * 1.8) + assert client.loop.create_connection.call_count == 2 + assert client.reconnect_delay_current == commparams.reconnect_delay * 2 + await asyncio.sleep(client.reconnect_delay_current * 1.8) + assert client.loop.create_connection.call_count == 3 + assert client.reconnect_delay_current == commparams.reconnect_delay_max + await asyncio.sleep(client.reconnect_delay_current * 1.8) + assert client.loop.create_connection.call_count >= 4 + assert client.reconnect_delay_current == commparams.reconnect_delay_max + client.transport_close() - @pytest.mark.xdist_group(name="server_serialize") - async def test_reconnect_call_ok(self, transport, use_port, commparams): + async def test_reconnect_call_ok(self, client, commparams): """Test connection_lost().""" - transport.setup_tcp(False, "localhost", use_port) - mocker = mock.AsyncMock(return_value=(mock.Mock(), mock.Mock())) - transport.loop.create_connection = mocker - transport.connection_made(mock.Mock()) - transport.connection_lost(RuntimeError("Connection lost")) - await asyncio.sleep(transport.reconnect_delay_current * 1.8) - assert mocker.call_count == 1 - assert transport.reconnect_delay_current == commparams.reconnect_delay - assert not transport.reconnect_task - transport.close() + client.loop.create_connection = mock.AsyncMock( + return_value=(mock.Mock(), mock.Mock()) + ) + await client.transport_connect() + client.connection_made(mock.Mock()) + client.connection_lost(RuntimeError("Connection lost")) + await asyncio.sleep(client.reconnect_delay_current * 1.8) + assert client.loop.create_connection.call_count == 2 + assert client.reconnect_delay_current == commparams.reconnect_delay + assert not client.reconnect_task + client.transport_close() diff --git a/test/test_client.py b/test/test_client.py index f628dfd26..6d0e4337a 100755 --- a/test/test_client.py +++ b/test/test_client.py @@ -103,7 +103,6 @@ def fake_execute(_self, request): assert isinstance(pdu_to_call, pdu_request) -@pytest.mark.xdist_group(name="client") @pytest.mark.parametrize( "arg_list", [ @@ -269,7 +268,7 @@ async def test_client_connection_made(): """Test protocol made connection.""" client = lib_client.AsyncModbusTcpClient("127.0.0.1") assert not client.connected - client.new_transport.connection_made(mock.AsyncMock()) + client.connection_made(mock.AsyncMock()) assert client.connected client.close() @@ -299,15 +298,15 @@ async def test_client_protocol_receiver(): """Test the client protocol data received""" base = ModbusBaseClient(framer=ModbusSocketFramer) transport = mock.MagicMock() - base.new_transport.connection_made(transport) - assert base.new_transport.transport == transport - assert base.new_transport.transport + base.connection_made(transport) + assert base.transport == transport + assert base.transport data = b"\x00\x00\x12\x34\x00\x06\xff\x01\x01\x02\x00\x04" # setup existing request assert not list(base.transaction) response = base._build_response(0x00) # pylint: disable=protected-access - base.new_transport.data_received(data) + base.data_received(data) result = response.result() assert isinstance(result, pdu_bit_read.ReadCoilsResponse) @@ -334,7 +333,7 @@ async def test_client_protocol_handler(): """Test the client protocol handles responses""" base = ModbusBaseClient(framer=ModbusSocketFramer) transport = mock.MagicMock() - base.new_transport.connection_made(transport=transport) + base.connection_made(transport=transport) reply = pdu_bit_read.ReadCoilsRequest(1, 1) reply.transaction_id = 0x00 base._handle_response(None) # pylint: disable=protected-access @@ -350,7 +349,7 @@ async def test_client_protocol_execute(): """Test the client protocol execute method""" base = ModbusBaseClient(host="127.0.0.1", framer=ModbusSocketFramer) transport = mock.MagicMock() - base.new_transport.connection_made(transport) + base.connection_made(transport) base.transport.write = mock.Mock() request = pdu_bit_read.ReadCoilsRequest(1, 1) @@ -364,15 +363,6 @@ async def test_client_protocol_execute(): response = await base.async_execute(request) -def test_client_udp(): - """Test client udp.""" - base = ModbusBaseClient(host="127.0.0.1", framer=ModbusSocketFramer) - base.new_transport.datagram_received(bytes("00010000", "utf-8"), 1) - base.transport = mock.MagicMock() - base.use_udp = True - base.new_transport.send(bytes("00010000", "utf-8")) - - def test_client_udp_connect(): """Test the Udp client connection method""" with mock.patch.object(socket, "socket") as mock_method: diff --git a/test/test_client_sync.py b/test/test_client_sync.py index d093e1915..cca28f252 100755 --- a/test/test_client_sync.py +++ b/test/test_client_sync.py @@ -1,5 +1,4 @@ """Test client sync.""" -import ssl from itertools import count from test.conftest import mockSocket from unittest import mock @@ -13,7 +12,6 @@ ModbusTlsClient, ModbusUdpClient, ) -from pymodbus.client.tls import sslctx_provider from pymodbus.exceptions import ConnectionException from pymodbus.transaction import ( ModbusAsciiFramer, @@ -208,28 +206,6 @@ class CustomRequest: # pylint: disable=too-few-public-methods # Test TLS Client # -----------------------------------------------------------------------# - def test_tls_sslctx_provider(self): - """Test that sslctx_provider() produce SSLContext correctly""" - with mock.patch.object(ssl.SSLContext, "load_cert_chain") as mock_method: - sslctx1 = sslctx_provider(certfile="cert.pem") - assert sslctx1 - assert isinstance(sslctx1, ssl.SSLContext) - assert not mock_method.called - - sslctx2 = sslctx_provider(keyfile="key.pem") - assert sslctx2 - assert isinstance(sslctx2, ssl.SSLContext) - assert not mock_method.called - - sslctx3 = sslctx_provider(certfile="cert.pem", keyfile="key.pem") - assert sslctx3 - assert isinstance(sslctx3, ssl.SSLContext) - assert mock_method.called - - sslctx_old = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) - sslctx_new = sslctx_provider(sslctx=sslctx_old) - assert sslctx_new == sslctx_old - def test_syn_tls_client_instantiation(self): """Test sync tls client.""" # default SSLContext diff --git a/test/test_server_asyncio.py b/test/test_server_asyncio.py index 5f63a5962..4ef5f5e7b 100755 --- a/test/test_server_asyncio.py +++ b/test/test_server_asyncio.py @@ -229,11 +229,11 @@ async def test_async_tcp_server_connection_lost(self): """Test tcp stream interruption""" await self.start_server() await self.connect_server() - assert len(self.server.active_connections), 1 + assert len(self.server.local_active_connections), 1 BasicClient.transport.close() await asyncio.sleep(0.2) # so we have to wait a bit - assert not self.server.active_connections + assert not self.server.local_active_connections async def test_async_tcp_server_close_connection(self): """Test server_close() while there are active TCP connections""" @@ -276,14 +276,12 @@ async def test_async_start_tls_server_no_loop(self): with mock.patch.object(ssl.SSLContext, "load_cert_chain"): await self.start_server(do_tls=True, do_forever=False, do_ident=True) assert self.server.control.Identity.VendorName == "VendorName" - assert self.server.sslctx async def test_async_start_tls_server(self): """Test that the modbus tls asyncio server starts correctly""" with mock.patch.object(ssl.SSLContext, "load_cert_chain"): await self.start_server(do_tls=True, do_ident=True) assert self.server.control.Identity.VendorName == "VendorName" - assert self.server.sslctx async def test_async_tls_server_serve_forever(self): """Test StartAsyncTcpServer serve_forever() method""" diff --git a/test/test_unix_socket.py b/test/test_unix_socket.py deleted file mode 100755 index 28186d6d2..000000000 --- a/test/test_unix_socket.py +++ /dev/null @@ -1,61 +0,0 @@ -"""Test client async.""" -import asyncio -from tempfile import gettempdir - -import pytest -import pytest_asyncio - -from pymodbus.client import AsyncModbusTcpClient -from pymodbus.datastore import ( - ModbusSequentialDataBlock, - ModbusServerContext, - ModbusSlaveContext, -) -from pymodbus.server import ServerAsyncStop, StartAsyncUnixServer -from pymodbus.transaction import ModbusSocketFramer - - -PATH = gettempdir() + "/unix_domain_socket" -HOST = f"unix:{PATH}" - - -@pytest_asyncio.fixture(name="_mock_run_server") -async def _helper_server(path_addon): - """Run server.""" - datablock = ModbusSequentialDataBlock(0x00, [17] * 100) - context = ModbusSlaveContext( - di=datablock, co=datablock, hr=datablock, ir=datablock, slave=1 - ) - asyncio.create_task( # noqa: RUF006 - StartAsyncUnixServer( - context=ModbusServerContext(slaves=context, single=True), - path=PATH + path_addon, - framer=ModbusSocketFramer, - ) - ) - await asyncio.sleep(0.1) - yield - await ServerAsyncStop() - - -@pytest.mark.skipif(pytest.IS_WINDOWS, reason="Windows have a timeout problem.") -@pytest.mark.parametrize("path_addon", ["_1"]) -async def test_unix_server(_mock_run_server): - """Run async server with unix domain socket.""" - await asyncio.sleep(0.1) - - -@pytest.mark.skipif(pytest.IS_WINDOWS, reason="Windows have a timeout problem.") -@pytest.mark.parametrize("path_addon", ["_2"]) -async def test_unix_async_client(path_addon, _mock_run_server): - """Run async client with unix domain socket.""" - await asyncio.sleep(1) - client = AsyncModbusTcpClient( - HOST + path_addon, - framer=ModbusSocketFramer, - ) - await client.connect() - assert client.connected - - rr = await client.read_coils(1, 1, slave=1) - assert not rr.isError()