diff --git a/pymodbus/transport/transport.py b/pymodbus/transport/transport.py index e6b09ea1f..1e978cd5d 100644 --- a/pymodbus/transport/transport.py +++ b/pymodbus/transport/transport.py @@ -266,9 +266,11 @@ async def transport_listen(self) -> bool: self.loop = asyncio.get_running_loop() self.is_closing = False try: - self.transport = await self.call_create() - if isinstance(self.transport, tuple): - self.transport = self.transport[0] + _transport = await self.call_create() + if isinstance(_transport, tuple): + self.transport = _transport[0] + else: + self.transport = _transport except OSError as exc: Log.warning("Failed to start server {}", exc) # self.transport_close(intern=True) @@ -435,7 +437,7 @@ def transport_close(self, intern: bool = False, reconnect: bool = False) -> None def reset_delay(self) -> None: """Reset wait time before next reconnect to minimal period.""" - self.reconnect_delay_current = self.comm_params.reconnect_delay + self.reconnect_delay_current = self.comm_params.reconnect_delay or 0 def is_active(self) -> bool: """Return true if connected/listening.""" @@ -469,6 +471,12 @@ def handle_new_connection(self) -> ModbusProtocol: async def do_reconnect(self) -> None: """Handle reconnect as a task.""" + if not ( + self.comm_params.reconnect_delay and self.comm_params.reconnect_delay_max + ): + raise AssertionError( + "do_reconnect should not be called if reconnect_delay is None" + ) try: self.reconnect_delay_current = self.comm_params.reconnect_delay while True: @@ -519,7 +527,7 @@ def __init__(self, protocol: ModbusProtocol, listen: int | None = None) -> None: asyncio.DatagramTransport.__init__(self) asyncio.Transport.__init__(self) self.protocol: ModbusProtocol = protocol - self.other_modem: NullModem = None + self.other_modem: NullModem | None = None self.listen = listen self.manipulator: Callable[[bytes], list[bytes]] | None = None self._is_closing = False @@ -605,6 +613,8 @@ def sendto(self, data: bytes, _addr: Any = None) -> None: def write(self, data: bytes) -> None: """Send data.""" + if not self.other_modem: + raise AssertionError("Missing other_modem") if not self.manipulator: self.other_modem.protocol.data_received(data) return diff --git a/pymodbus/transport/transport_serial.py b/pymodbus/transport/transport_serial.py index 63c1affa5..b3ca34b38 100644 --- a/pymodbus/transport/transport_serial.py +++ b/pymodbus/transport/transport_serial.py @@ -20,7 +20,7 @@ def __init__(self, loop, protocol, *args, **kwargs) -> None: super().__init__() self.async_loop = loop self._protocol: asyncio.BaseProtocol = protocol - self.sync_serial = serial.serial_for_url(*args, **kwargs) + self.sync_serial: serial.Serial | None = serial.serial_for_url(*args, **kwargs) self._write_buffer: list[bytes] = [] self.poll_task: asyncio.Task | None = None self._poll_wait_time = 0.0005 @@ -59,12 +59,12 @@ def close(self, exc: Exception | None = None) -> None: def write(self, data) -> None: """Write some data to the transport.""" self._write_buffer.append(data) - if not self.poll_task: + if not self.poll_task and self.sync_serial: self.async_loop.add_writer(self.sync_serial.fileno(), self._write_ready) def flush(self) -> None: """Clear output buffer and stops any more data being written.""" - if not self.poll_task: + if not self.poll_task and self.sync_serial: self.async_loop.remove_writer(self.sync_serial.fileno()) self._write_buffer.clear() diff --git a/pyproject.toml b/pyproject.toml index 17cd7f42d..d7a1c60b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -187,7 +187,7 @@ overgeneral-exceptions = "builtins.Exception" bad-functions = "map,input" [tool.mypy] -strict_optional = false +strict_optional = true show_error_codes = true local_partial_types = true strict_equality = true