diff --git a/examples/client_async.py b/examples/client_async.py index 332a1a04c..6874d5262 100755 --- a/examples/client_async.py +++ b/examples/client_async.py @@ -124,7 +124,7 @@ async def run_async_client(client, modbus_calls=None): assert client.connected if modbus_calls: await modbus_calls(client) - await client.close() + client.close() _logger.info("### End of Program") diff --git a/pymodbus/client/base.py b/pymodbus/client/base.py index e6b7a4f58..deb4d5c10 100644 --- a/pymodbus/client/base.py +++ b/pymodbus/client/base.py @@ -14,10 +14,11 @@ from pymodbus.logging import Log from pymodbus.pdu import ModbusRequest, ModbusResponse from pymodbus.transaction import DictTransactionManager +from pymodbus.transport import BaseTransport from pymodbus.utilities import ModbusTransactionState -class ModbusBaseClient(ModbusClientMixin): +class ModbusBaseClient(ModbusClientMixin, BaseTransport): """**ModbusBaseClient** **Parameters common to all clients**: @@ -63,8 +64,6 @@ class _params: # pylint: disable=too-many-instance-attributes broadcast_enable: bool = None kwargs: dict = None reconnect_delay: int = None - reconnect_delay_max: int = None - on_reconnect_callback: Callable[[], None] | None = None baudrate: int = None bytesize: int = None @@ -95,6 +94,7 @@ def __init__( # pylint: disable=too-many-arguments **kwargs: Any, ) -> None: """Initialize a client instance.""" + BaseTransport.__init__(self) self.params = self._params() self.params.framer = framer self.params.timeout = float(timeout) @@ -104,8 +104,8 @@ def __init__( # pylint: disable=too-many-arguments self.params.strict = bool(strict) self.params.broadcast_enable = bool(broadcast_enable) self.params.reconnect_delay = int(reconnect_delay) - self.params.reconnect_delay_max = int(reconnect_delay_max) - self.params.on_reconnect_callback = on_reconnect_callback + self.reconnect_delay_max = int(reconnect_delay_max) + self.on_reconnect_callback = on_reconnect_callback self.params.kwargs = kwargs # Common variables. @@ -115,15 +115,14 @@ def __init__( # pylint: disable=too-many-arguments ) self.delay_ms = self.params.reconnect_delay self.use_protocol = False - self._connected = False self.use_udp = False self.state = ModbusTransactionState.IDLE self.last_frame_end: float = 0 self.silent_interval: float = 0 - self.transport = None + self._reconnect_task = None # Initialize mixin - super().__init__() + ModbusClientMixin.__init__(self) # ----------------------------------------------------------------------- # # Client external interface @@ -174,30 +173,16 @@ def execute(self, request: ModbusRequest = None) -> ModbusResponse: :raises ConnectionException: Check exception text. """ if self.use_protocol: - if not self._connected: + if not self.transport: raise ConnectionException(f"Not connected[{str(self)}]") return self.async_execute(request) if not self.connect(): raise ConnectionException(f"Failed to connect[{str(self)}]") return self.transaction.execute(request) - def close(self) -> None: - """Close the underlying socket connection (call **sync/async**).""" - raise NotImplementedException - # ----------------------------------------------------------------------- # # Merged client methods # ----------------------------------------------------------------------- # - def client_made_connection(self, protocol): - """Run transport specific connection.""" - - def client_lost_connection(self, protocol): - """Run transport specific connection lost.""" - - def datagram_received(self, data, _addr): - """Receive datagram.""" - self.data_received(data) - async def async_execute(self, request=None): """Execute requests asynchronously.""" request.transaction_id = self.transaction.getNextTID() @@ -218,29 +203,13 @@ async def async_execute(self, request=None): raise return resp - def connection_made(self, transport): - """Call when a connection is made. - - The transport argument is the transport representing the connection. - """ - self.transport = transport - Log.debug("Client connected to modbus server") - self._connected = True - self.client_made_connection(self) - def connection_lost(self, reason): """Call when the connection is lost or closed. The argument is either an exception object or None """ - if self.transport: - self.transport.abort() - if hasattr(self.transport, "_sock"): - self.transport._sock.close() # pylint: disable=protected-access - self.transport = None - self.client_lost_connection(self) Log.debug("Client disconnected from modbus server: {}", reason) - self._connected = False + self.close(reconnect=True) for tid in list(self.transaction): self.raise_future( self.transaction.getTransaction(tid), @@ -277,22 +246,43 @@ def _handle_response(self, reply, **_kwargs): def _build_response(self, tid): """Return a deferred response for the current request.""" my_future = self.create_future() - if not self._connected: + if not self.transport: self.raise_future(my_future, ConnectionException("Client is not connected")) else: self.transaction.addTransaction(my_future, tid) return my_future - @property - def async_connected(self): - """Return connection status.""" - return self._connected + def close(self, reconnect: bool = False) -> None: + """Close connection. - async def async_close(self): - """Close connection.""" + :param reconnect: (default false), try to reconnect + """ if self.transport: + if hasattr(self.transport, "_sock"): + self.transport._sock.close() # pylint: disable=protected-access + self.transport.abort() self.transport.close() - self._connected = False + self.transport = None + if self._reconnect_task: + self._reconnect_task.cancel() + self._reconnect_task = None + + if not reconnect or not self.delay_ms: + self.delay_ms = 0 + return + + self._reconnect_task = asyncio.create_task(self._reconnect()) + + async def _reconnect(self): + """Reconnect.""" + Log.debug("Waiting {} ms before next connection attempt.", self.delay_ms) + await asyncio.sleep(self.delay_ms / 1000) + self.delay_ms = min(2 * self.delay_ms, self.reconnect_delay_max) + + self._reconnect_task = None + if self.on_reconnect_callback: + self.on_reconnect_callback() + return await self.connect() # ----------------------------------------------------------------------- # # Internal methods @@ -353,7 +343,7 @@ def __exit__(self, klass, value, traceback): async def __aexit__(self, klass, value, traceback): """Implement the client with exit block.""" - await self.close() + self.close() def __str__(self): """Build a string representation of the connection. diff --git a/pymodbus/client/serial.py b/pymodbus/client/serial.py index 9b99b8a19..e18a63618 100644 --- a/pymodbus/client/serial.py +++ b/pymodbus/client/serial.py @@ -42,7 +42,7 @@ async def run(): await client.connect() ... - await client.close() + client.close() """ transport = None @@ -68,25 +68,6 @@ def __init__( self.params.parity = parity self.params.stopbits = stopbits self.params.handle_local_echo = handle_local_echo - self.loop = None - self._connected_event = asyncio.Event() - self._reconnect_task = None - - async def close(self): # pylint: disable=invalid-overridden-method - """Stop connection.""" - - # prevent reconnect: - self.delay_ms = 0 - if self.connected: - if self.transport: - self.transport.close() - await self.async_close() - await asyncio.sleep(0.1) - - # if there is an unfinished delayed reconnection attempt pending, cancel it - if self._reconnect_task: - self._reconnect_task.cancel() - self._reconnect_task = None def _create_protocol(self): """Create a protocol instance.""" @@ -95,74 +76,33 @@ def _create_protocol(self): @property def connected(self): """Connect internal.""" - return self._connected_event.is_set() + return self.transport is not None async def connect(self): # pylint: disable=invalid-overridden-method """Connect Async client.""" # get current loop, if there are no loop a RuntimeError will be raised - self.loop = asyncio.get_running_loop() - Log.debug("Starting serial connection") try: - await create_serial_connection( - self.loop, - self._create_protocol, - self.params.port, - baudrate=self.params.baudrate, - bytesize=self.params.bytesize, - stopbits=self.params.stopbits, - parity=self.params.parity, + await asyncio.wait_for( + create_serial_connection( + self.loop, + self._create_protocol, + self.params.port, + baudrate=self.params.baudrate, + bytesize=self.params.bytesize, + stopbits=self.params.stopbits, + parity=self.params.parity, + timeout=self.params.timeout, + **self.params.kwargs, + ), timeout=self.params.timeout, - **self.params.kwargs, ) - await self._connected_event.wait() Log.info("Connected to {}", self.params.port) except Exception as exc: # pylint: disable=broad-except Log.warning("Failed to connect: {}", exc) - if self.delay_ms > 0: - self._launch_reconnect() + self.close(reconnect=True) return self.connected - def client_made_connection(self, protocol): - """Notify successful connection.""" - Log.info("Serial connected.") - if not self.connected: - self._connected_event.set() - else: - Log.error("Factory protocol connect callback called while connected.") - - def client_lost_connection(self, protocol): - """Notify lost connection.""" - Log.info("Serial lost connection.") - if protocol is not self: - Log.error("Serial: protocol is not self.") - - self._connected_event.clear() - if self.delay_ms: - self._launch_reconnect() - - def _launch_reconnect(self): - """Launch delayed reconnection coroutine""" - if self._reconnect_task: - Log.warning( - "Ignoring launch of delayed reconnection, another is in progress" - ) - else: - # store the future in a member variable so we know we have a pending reconnection attempt - # also prevents its garbage collection - self._reconnect_task = asyncio.create_task(self._reconnect()) - - async def _reconnect(self): - """Reconnect.""" - Log.debug("Waiting {} ms before next connection attempt.", self.delay_ms) - await asyncio.sleep(self.delay_ms / 1000) - self.delay_ms = min(2 * self.delay_ms, self.params.reconnect_delay_max) - - self._reconnect_task = None - if self.params.on_reconnect_callback: - self.params.on_reconnect_callback() - return await self.connect() - class ModbusSerialClient(ModbusBaseClient): """**ModbusSerialClient**. @@ -267,7 +207,7 @@ def connect(self): self.close() return self.socket is not None - def close(self): + def close(self): # pylint: disable=arguments-differ """Close the underlying socket connection.""" if self.socket: self.socket.close() diff --git a/pymodbus/client/tcp.py b/pymodbus/client/tcp.py index d4ec2f4a0..a87f2942a 100644 --- a/pymodbus/client/tcp.py +++ b/pymodbus/client/tcp.py @@ -34,7 +34,7 @@ async def run(): await client.connect() ... - await client.close() + client.close() """ def __init__( @@ -51,10 +51,7 @@ def __init__( self.params.host = host self.params.port = port self.params.source_address = source_address - self.loop = None - self.connected = False self.delay_ms = self.params.reconnect_delay - self._reconnect_task = None async def connect(self): # pylint: disable=invalid-overridden-method """Initiate connection to start client.""" @@ -64,23 +61,13 @@ async def connect(self): # pylint: disable=invalid-overridden-method self.reset_delay() # force reconnect if required: - self.loop = asyncio.get_running_loop() Log.debug("Connecting to {}:{}.", self.params.host, self.params.port) return await self._connect() - async def close(self): # pylint: disable=invalid-overridden-method - """Stop client.""" - self.delay_ms = 0 - if self.connected: - if self.transport: - self.transport.abort() - self.transport.close() - await self.async_close() - await asyncio.sleep(0.1) - - if self._reconnect_task: - self._reconnect_task.cancel() - self._reconnect_task = None + @property + def connected(self): + """Return true if connected.""" + return self.transport is not None def _create_protocol(self): """Create initialized protocol instance with function.""" @@ -108,51 +95,12 @@ async def _connect(self): ) except Exception as exc: # pylint: disable=broad-except Log.warning("Failed to connect: {}", exc) - if self.delay_ms > 0: - self._launch_reconnect() + self.close(reconnect=True) else: Log.info("Connected to {}:{}.", self.params.host, self.params.port) self.reset_delay() return transport, protocol - def client_made_connection(self, protocol): - """Notify successful connection.""" - Log.info("Protocol made connection.") - if not self.connected: - self.connected = True - else: - Log.error("Factory protocol connect callback called while connected.") - - def client_lost_connection(self, protocol): - """Notify lost connection.""" - Log.info("Protocol lost connection.") - if protocol is not self: - Log.error("Factory protocol cb from unknown protocol instance.") - - self.connected = False - if self.delay_ms > 0: - self._launch_reconnect() - - def _launch_reconnect(self): - """Launch delayed reconnection coroutine""" - if self._reconnect_task: - Log.warning( - "Ignoring launch of delayed reconnection, another is in progress" - ) - else: - self._reconnect_task = asyncio.create_task(self._reconnect()) - - async def _reconnect(self): - """Reconnect.""" - Log.debug("Waiting {} ms before next connection attempt.", self.delay_ms) - await asyncio.sleep(self.delay_ms / 1000) - self.delay_ms = min(2 * self.delay_ms, self.params.reconnect_delay_max) - - self._reconnect_task = None - if self.params.on_reconnect_callback: - self.params.on_reconnect_callback() - return await self._connect() - class ModbusTcpClient(ModbusBaseClient): """**ModbusTcpClient**. @@ -228,7 +176,7 @@ def connect(self): self.close() return self.socket is not None - def close(self): + def close(self): # pylint: disable=arguments-differ """Close the underlying socket connection.""" if self.socket: self.socket.close() diff --git a/pymodbus/client/tls.py b/pymodbus/client/tls.py index e2adbcb79..3894b0035 100644 --- a/pymodbus/client/tls.py +++ b/pymodbus/client/tls.py @@ -66,7 +66,7 @@ async def run(): await client.connect() ... - await client.close() + client.close() """ def __init__( @@ -104,8 +104,7 @@ async def _connect(self): ) except Exception as exc: # pylint: disable=broad-except Log.warning("Failed to connect: {}", exc) - if self.delay_ms > 0: - self._launch_reconnect() + self.close(reconnect=True) return Log.info("Connected to {}:{}.", self.params.host, self.params.port) self.reset_delay() diff --git a/pymodbus/client/udp.py b/pymodbus/client/udp.py index 5cef1ee86..8a4721a9b 100644 --- a/pymodbus/client/udp.py +++ b/pymodbus/client/udp.py @@ -37,7 +37,7 @@ async def run(): await client.connect() ... - await client.close() + client.close() """ def __init__( @@ -54,47 +54,30 @@ def __init__( self.params.host = host self.params.port = port self.params.source_address = source_address - self._reconnect_task = None - self.loop = asyncio.get_event_loop() - self.connected = False self.delay_ms = self.params.reconnect_delay - self._reconnect_task = None self.reset_delay() + @property + def connected(self): + """Return true if connected.""" + return self.transport is not None + async def connect(self): # pylint: disable=invalid-overridden-method """Start reconnecting asynchronous udp client. :meta private: """ # get current loop, if there are no loop a RuntimeError will be raised - self.loop = asyncio.get_running_loop() Log.debug("Connecting to {}:{}.", self.params.host, self.params.port) # getaddrinfo returns a list of tuples # - [(family, type, proto, canonname, sockaddr),] # We want sockaddr which is a (ip, port) tuple # udp needs ip addresses, not hostnames - # TBD: addrinfo = await self.loop.getaddrinfo(self.params.host, self.params.port, type=DGRAM_TYPE) + # TBD: addrinfo = await getaddrinfo(self.params.host, self.params.port, type=DGRAM_TYPE) # TBD: self.params.host, self.params.port = addrinfo[-1][-1] return await self._connect() - async def close(self): # pylint: disable=invalid-overridden-method - """Stop connection and prevents reconnect. - - :meta private: - """ - self.delay_ms = 0 - if self.connected: - if self.transport: - self.transport.abort() - self.transport.close() - await self.async_close() - await asyncio.sleep(0.1) - - if self._reconnect_task: - self._reconnect_task.cancel() - self._reconnect_task = None - def _create_protocol(self): """Create initialized protocol instance with function.""" self.use_udp = True @@ -112,50 +95,7 @@ async def _connect(self): return endpoint except Exception as exc: # pylint: disable=broad-except Log.warning("Failed to connect: {}", exc) - self._reconnect_task = asyncio.ensure_future(self._reconnect()) - - def client_made_connection(self, protocol): - """Notify successful connection. - - :meta private: - """ - Log.info("Protocol made connection.") - if not self.connected: - self.connected = True - else: - Log.error("Factory protocol connect callback called while connected.") - - def client_lost_connection(self, protocol): - """Notify lost connection. - - :meta private: - """ - Log.info("Protocol lost connection.") - if protocol is not self: - Log.error("Factory protocol cb from unexpected protocol instance.") - - self.connected = False - if self.delay_ms > 0: - self._launch_reconnect() - - def _launch_reconnect(self): - """Launch delayed reconnection coroutine""" - if self._reconnect_task: - Log.warning( - "Ignoring launch of delayed reconnection, another is in progress" - ) - else: - self._reconnect_task = asyncio.create_task(self._reconnect()) - - async def _reconnect(self): - """Reconnect.""" - Log.debug("Waiting {} ms before next connection attempt.", self.delay_ms) - await asyncio.sleep(self.delay_ms / 1000) - self.delay_ms = 2 * self.delay_ms - - if self.params.on_reconnect_callback: - self.params.on_reconnect_callback() - return await self._connect() + self.close(reconnect=True) class ModbusUdpClient(ModbusBaseClient): @@ -200,14 +140,6 @@ def __init__( self.socket = None - @property - def connected(self): - """Connect internal. - - :meta private: - """ - return self.connect() - def connect(self): """Connect to the modbus tcp server. @@ -224,7 +156,7 @@ def connect(self): self.close() return self.socket is not None - def close(self): + def close(self): # pylint: disable=arguments-differ """Close the underlying socket connection. :meta private: diff --git a/pymodbus/transport/__init__.py b/pymodbus/transport/__init__.py new file mode 100644 index 000000000..731a7e780 --- /dev/null +++ b/pymodbus/transport/__init__.py @@ -0,0 +1,7 @@ +"""Transport.""" + +__all__ = [ + "BaseTransport", +] + +from pymodbus.transport.base import BaseTransport diff --git a/pymodbus/transport/base.py b/pymodbus/transport/base.py new file mode 100644 index 000000000..5fb66b586 --- /dev/null +++ b/pymodbus/transport/base.py @@ -0,0 +1,110 @@ +"""Base for all transport types.""" +from __future__ import annotations + +import asyncio +from typing import Any, Callable + +from pymodbus.framer import ModbusFramer +from pymodbus.logging import Log + + +class BaseTransport: + """Base class for transport types. + + BaseTransport contains functions common to all transport types and client/server. + + This class is not available in the pymodbus API, and should not be referenced in Applications. + """ + + def __init__(self) -> None: + """Initialize a transport instance.""" + # parameter variables, overwritten in child classes + self.framer: ModbusFramer | None = None + # -> framer: framer used to encode/decode data + self.slaves: list[int] = [] + # -> slaves: list of acceptable slaves (0 for accept all) + self.comm_name: str = "" + # -> comm_name: name of this transport connection + self.reconnect_delay: int = -1 + # -> reconnect_delay: delay in milliseconds for first reconnect (0 for no reconnect) + self.reconnect_delay_max: int = -1 + # -> reconnect_delay_max: max delay in milliseconds for next reconnect, resets to reconnect_delay + self.retries_send: int = -1 + # -> retries_send: number of times to retry a send operation + self.retry_on_empty: int = -1 + # -> retry_on_empty: retry read on nothing + self.timeout_connect: bool = None + # -> timeout_connect: Max. time in milliseconds for connect to complete + self.timeout_comm: int = -1 + # -> timeout_comm: Max. time in milliseconds for recv/send to complete + self.on_connection_made: Callable[[str], None] = lambda x: None + # -> on_connection_made: callback when connection is established and opened + self.on_connection_lost: Callable[[str, Exception], None] = lambda x, y: None + # -> on_connection_lost: callback when connection is lost and closed + + # properties, can be read, but may not be mingled with + self.reconnect_delay_current: int = 0 + # -> reconnect_delay_current: current delay in milliseconds for next reconnect (doubles with every try) + self.transport: Any = None + # -> transport: current transport class (None if not connected) + self.loop = asyncio.get_event_loop() + # -> loop: current asyncio event loop + + # -------------------------------------------- # + # Transport external methods (asyncio defined) # + # -------------------------------------------- # + def connection_made(self, transport: Any): + """Call from asyncio, when a connection is made. + + :param transport: socket etc. representing the connection. + """ + self.transport = transport + Log.debug("Connected {}", self.comm_name) + self.on_connection_made(self.comm_name) + + def connection_lost(self, reason: Exception): + """Call from asyncio, when the connection is lost or closed. + + :param reason: None or an exception object + """ + self.transport = None + Log.debug("Connection lost {} due to {}", self.comm_name, reason) + self.on_connection_lost(self.comm_name, reason) + + def data_received(self, data: bytes): + """Call when some data is received. + + :param data: non-empty bytes object with incoming data. + """ + Log.debug("recv: {}", data, ":hex") + # self.framer.processIncomingPacket(data, self._handle_response, unit=0) + + def datagram_received(self, data, _addr): + """Receive datagram.""" + self.data_received(data) + + def send(self, data: bytes) -> bool: + """Send request. + + :param data: non-empty bytes object with data to send. + """ + Log.debug("send: {}", data, ":hex") + return False + + def close(self) -> None: + """Close the underlying connection.""" + + # ----------------------------------------------------------------------- # + # The magic methods + # ----------------------------------------------------------------------- # + async def __aenter__(self): + """Implement the client with async enter block.""" + return self + + async def __aexit__(self, _class, _value, _traceback) -> None: + """Implement the client with async exit block.""" + self.close() + + def __str__(self) -> str: + """Build a string representation of the connection.""" + return f"{self.__class__.__name__}({self.comm_name})" diff --git a/test/test_client.py b/test/test_client.py index d5376ed66..ce7c84592 100755 --- a/test/test_client.py +++ b/test/test_client.py @@ -239,9 +239,6 @@ async def test_client_instanciate( to_test = dict(arg_list["fix"]["opt_args"], **cur_args["opt_args"]) to_test["host"] = cur_args["defaults"]["host"] - for arg, arg_test in to_test.items(): - assert getattr(client.params, arg) == arg_test - # Test information methods client.last_frame_end = 2 client.silent_interval = 2 @@ -264,17 +261,17 @@ async def test_client_instanciate( # a successful execute client.connect = lambda: True - client._connected = True # pylint: disable=protected-access + client.transport = lambda: None client.transaction = mock.Mock(**{"execute.return_value": True}) # a unsuccessful connect client.connect = lambda: False - client._connected = False # pylint: disable=protected-access + client.transport = None with pytest.raises(ConnectionException): client.execute() -def test_client_modbusbaseclient(): +async def test_client_modbusbaseclient(): """Test modbus base client class.""" client = ModbusBaseClient(framer=ModbusAsciiFramer) client.register(pdu_bit_read.ReadCoilsResponse) @@ -286,8 +283,6 @@ def test_client_modbusbaseclient(): client.connect() with pytest.raises(NotImplementedException): client.is_socket_open() - with pytest.raises(NotImplementedException): - client.close() with mock.patch( "pymodbus.client.base.ModbusBaseClient.connect" @@ -302,42 +297,30 @@ def test_client_modbusbaseclient(): p_connect.return_value = False -async def test_client_made_connection(): +async def test_client_connection_made(): """Test protocol made connection.""" client = lib_client.AsyncModbusTcpClient("127.0.0.1") assert not client.connected - client.client_made_connection(mock.sentinel.PROTOCOL) + client.connection_made(mock.sentinel.PROTOCOL) assert client.connected - client.client_made_connection(mock.sentinel.PROTOCOL_UNEXPECTED) + client.connection_made(mock.sentinel.PROTOCOL_UNEXPECTED) assert client.connected -async def test_client_lost_connection(): +async def test_client_connection_lost(): """Test protocol lost connection.""" client = lib_client.AsyncModbusTcpClient("127.0.0.1") assert not client.connected # fake client is connected and *then* looses connection: - client.connected = True client.params.host = mock.sentinel.HOST client.params.port = mock.sentinel.PORT - with mock.patch( - "pymodbus.client.tcp.AsyncModbusTcpClient._launch_reconnect" - ) as mock_reconnect: - mock_reconnect.return_value = mock.sentinel.RECONNECT_GENERATOR - - client.client_lost_connection(mock.sentinel.PROTOCOL_UNEXPECTED) + client.connection_lost(mock.sentinel.PROTOCOL_UNEXPECTED) assert not client.connected - - client.connected = True - with mock.patch( - "pymodbus.client.tcp.AsyncModbusTcpClient._launch_reconnect" - ) as mock_reconnect: - mock_reconnect.return_value = mock.sentinel.RECONNECT_GENERATOR - - client.client_lost_connection(mock.sentinel.PROTOCOL) + client.connection_lost(mock.sentinel.PROTOCOL) assert not client.connected + client.close() async def test_client_base_async(): @@ -361,46 +344,13 @@ async def test_client_base_async(): p_close.return_value.set_result(False) -@pytest.mark.skip -async def test_client_protocol(): - """Test base modbus async client.""" - base = ModbusBaseClient(framer=ModbusSocketFramer) - assert base.transport is None - assert not base.async_connected - - base.connection_made(mock.sentinel.TRANSPORT) - assert base.transport is mock.sentinel.TRANSPORT - base.client_made_connection.assert_called_once_with( # pylint: disable=no-member - base - ) - assert not base.client_lost_connection.call_count # pylint: disable=no-member - - base.connection_lost(mock.sentinel.REASON) - assert base.transport is None - assert not base.client_made_connection.call_count # pylint: disable=no-member - base.client_lost_connection.assert_called_once_with( # pylint: disable=no-member - base - ) - base.raise_future = mock.MagicMock() - request = mock.MagicMock() - base.transaction.addTransaction(request, 1) - base.connection_lost(mock.sentinel.REASON) - base.raise_future.assert_called_once() - call_args = base.raise_future.call_args.args - assert call_args[0] == request - assert isinstance(call_args[1], ConnectionException) - base.transport = mock.MagicMock() - base.transport = None - await base.async_close() - - async def test_client_protocol_receiver(): """Test the client protocol data received""" base = ModbusBaseClient(framer=ModbusSocketFramer) transport = mock.MagicMock() base.connection_made(transport) assert base.transport == transport - assert base.async_connected + assert base.transport data = b"\x00\x00\x12\x34\x00\x06\xff\x01\x01\x02\x00\x04" # setup existing request @@ -410,7 +360,7 @@ async def test_client_protocol_receiver(): result = response.result() assert isinstance(result, pdu_bit_read.ReadCoilsResponse) - base._connected = False # pylint: disable=protected-access + base.transport = None with pytest.raises(ConnectionException): await base._build_response(0x00) # pylint: disable=protected-access @@ -423,7 +373,7 @@ async def test_client_protocol_response(): assert isinstance(excp, ConnectionException) assert not list(base.transaction) - base._connected = True # pylint: disable=protected-access + base.transport = lambda: None base._build_response(0x00) # pylint: disable=protected-access assert len(list(base.transaction)) == 1 diff --git a/test/test_server_task.py b/test/test_server_task.py index 633757b8b..3fac6b9a8 100755 --- a/test/test_server_task.py +++ b/test/test_server_task.py @@ -153,7 +153,7 @@ async def test_async_task_no_server(comm): await asyncio.sleep(0.1) with pytest.raises((asyncio.exceptions.TimeoutError, ConnectionException)): await client.read_coils(1, 1, slave=0x01) - await client.close() + client.close() @pytest.mark.xdist_group(name="server_serialize") @@ -167,13 +167,13 @@ async def test_async_task_ok(comm): client = run_client(**client_args) await client.connect() await asyncio.sleep(0.1) - assert client._connected # pylint: disable=protected-access + assert client.transport rr = await client.read_coils(1, 1, slave=0x01) assert len(rr.bits) == 8 - await client.close() + client.close() await asyncio.sleep(0.1) - assert not client._connected # pylint: disable=protected-access + assert not client.transport await server.ServerAsyncStop() task.cancel() await task @@ -190,23 +190,23 @@ async def test_async_task_reuse(comm): client = run_client(**client_args) await client.connect() await asyncio.sleep(0.1) - assert client._connected # pylint: disable=protected-access + assert client.transport rr = await client.read_coils(1, 1, slave=0x01) assert len(rr.bits) == 8 - await client.close() + client.close() await asyncio.sleep(0.1) - assert not client._connected # pylint: disable=protected-access + assert not client.transport await client.connect() await asyncio.sleep(0.1) - assert client._connected # pylint: disable=protected-access + assert client.transport rr = await client.read_coils(1, 1, slave=0x01) assert len(rr.bits) == 8 - await client.close() + client.close() await asyncio.sleep(0.1) - assert not client._connected # pylint: disable=protected-access + assert not client.transport await server.ServerAsyncStop() task.cancel() @@ -225,7 +225,7 @@ async def test_async_task_server_stop(comm): client = run_client(**client_args, on_reconnect_callback=on_reconnect_callback) await client.connect() - assert client._connected # pylint: disable=protected-access + assert client.transport rr = await client.read_coils(1, 1, slave=0x01) assert len(rr.bits) == 8 on_reconnect_callback.assert_not_called() @@ -236,27 +236,27 @@ async def test_async_task_server_stop(comm): with pytest.raises((ConnectionException, asyncio.exceptions.TimeoutError)): rr = await client.read_coils(1, 1, slave=0x01) - assert not client._connected # pylint: disable=protected-access + assert not client.transport # Server back online task = asyncio.create_task(run_server(**server_args)) await asyncio.sleep(0.1) timer_allowed = 100 - while not client._connected: # pylint: disable=protected-access + while not client.transport: await asyncio.sleep(0.1) timer_allowed -= 1 if not timer_allowed: assert False, "client do not reconnect" - assert client._connected # pylint: disable=protected-access + assert client.transport on_reconnect_callback.assert_called() rr = await client.read_coils(1, 1, slave=0x01) assert len(rr.bits) == 8 - await client.close() + client.close() await asyncio.sleep(0.5) - assert not client._connected # pylint: disable=protected-access + assert not client.transport await server.ServerAsyncStop() await task diff --git a/test/transport/test_basic.py b/test/transport/test_basic.py new file mode 100644 index 000000000..e65cd02e6 --- /dev/null +++ b/test/transport/test_basic.py @@ -0,0 +1,28 @@ +"""Test transport.""" +# import pytest + +from pymodbus.transport.base import BaseTransport + + +class TestTransport: + """Unittest for the transport module.""" + + def test_base_properties(self): + """Test properties.""" + BaseTransport() + # assert not transport.ps_close_comm_on_error + + # with pytest.raises(RuntimeError): + # transport.connection_made(None) + + def test_base_1(self): + """Test properties.""" + + def test_base_2(self): + """Test properties.""" + + def test_base_3(self): + """Test properties.""" + + def test_base_4(self): + """Test properties."""