From cd6a441ad793bb9be27c5139e9b6a64a013f5f91 Mon Sep 17 00:00:00 2001 From: James Hilliard Date: Thu, 18 Jul 2024 11:46:56 -0600 Subject: [PATCH] Add more typing to transaction and base client --- pymodbus/client/base.py | 16 ++++++++-------- pymodbus/transaction.py | 41 +++++++++++++++++++++++++---------------- 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/pymodbus/client/base.py b/pymodbus/client/base.py index 90cbecc2df..8900d5f568 100644 --- a/pymodbus/client/base.py +++ b/pymodbus/client/base.py @@ -4,7 +4,7 @@ import asyncio import socket from abc import abstractmethod -from collections.abc import Awaitable, Callable +from collections.abc import Callable, Coroutine from dataclasses import dataclass from typing import Any, cast @@ -20,7 +20,7 @@ from pymodbus.utilities import ModbusTransactionState -class ModbusBaseClient(ModbusClientMixin[Awaitable[ModbusResponse]]): +class ModbusBaseClient(ModbusClientMixin[Coroutine[Any, Any, ModbusResponse | None]]): """**ModbusBaseClient**. Fixed parameters: @@ -141,7 +141,7 @@ def idle_time(self) -> float: return 0 return self.last_frame_end + self.silent_interval - def execute(self, request: ModbusRequest): + def execute(self, request: ModbusRequest) -> Coroutine[Any, Any, ModbusResponse | None]: """Execute request and get response (call **sync/async**). :param request: The request to process @@ -155,7 +155,7 @@ def execute(self, request: ModbusRequest): # ----------------------------------------------------------------------- # # Merged client methods # ----------------------------------------------------------------------- # - async def async_execute(self, request) -> ModbusResponse: + async def async_execute(self, request) -> ModbusResponse | None: """Execute requests asynchronously.""" request.transaction_id = self.ctx.transaction.getNextTID() packet = self.ctx.framer.buildPacket(request) @@ -183,9 +183,9 @@ async def async_execute(self, request) -> ModbusResponse: f"ERROR: No response received after {self.retries} retries" ) - return resp # type: ignore[return-value] + return resp - def build_response(self, request: ModbusRequest): + def build_response(self, request: ModbusRequest) -> asyncio.Future[ModbusResponse]: """Return a deferred response for the current request.""" my_future: asyncio.Future = asyncio.Future() request.fut = my_future @@ -222,7 +222,7 @@ def __str__(self): ) -class ModbusBaseSyncClient(ModbusClientMixin[ModbusResponse]): +class ModbusBaseSyncClient(ModbusClientMixin[ModbusResponse | bytes | ModbusIOException]): """**ModbusBaseClient**. Fixed parameters: @@ -336,7 +336,7 @@ def idle_time(self) -> float: return 0 return self.last_frame_end + self.silent_interval - def execute(self, request: ModbusRequest) -> ModbusResponse: + def execute(self, request: ModbusRequest) -> ModbusResponse | bytes | ModbusIOException: """Execute request and get response (call **sync/async**). :param request: The request to process diff --git a/pymodbus/transaction.py b/pymodbus/transaction.py index bbb187efe0..ad13fb1c61 100644 --- a/pymodbus/transaction.py +++ b/pymodbus/transaction.py @@ -29,7 +29,7 @@ ModbusTlsFramer, ) from pymodbus.logging import Log -from pymodbus.pdu import ModbusRequest +from pymodbus.pdu import ModbusRequest, ModbusResponse from pymodbus.transport import CommType from pymodbus.utilities import ModbusTransactionState, hexlify_packets @@ -167,13 +167,13 @@ def _set_adu_size(self): else: self.base_adu_size = -1 - def _calculate_response_length(self, expected_pdu_size): + def _calculate_response_length(self, expected_pdu_size: int) -> int | None: """Calculate response length.""" if self.base_adu_size == -1: return None return self.base_adu_size + expected_pdu_size - def _calculate_exception_length(self): + def _calculate_exception_length(self) -> int | None: """Return the length of the Modbus Exception Response according to the type of Framer.""" if isinstance(self.client.framer, (ModbusSocketFramer, ModbusTlsFramer)): return self.base_adu_size + 2 # Fcode(1), ExceptionCode(1) @@ -183,7 +183,9 @@ def _calculate_exception_length(self): return self.base_adu_size + 2 # Fcode(1), ExceptionCode(1) return None - def _validate_response(self, request: ModbusRequest, response, exp_resp_len, is_udp=False): + def _validate_response( + self, request: ModbusRequest, response: bytes | int, exp_resp_len: int | None, is_udp=False + ) -> bool: """Validate Incoming response against request. :param request: Request sent @@ -208,7 +210,7 @@ def _validate_response(self, request: ModbusRequest, response, exp_resp_len, is_ return mbap.get("length") == exp_resp_len return True - def execute(self, request: ModbusRequest): # noqa: C901 + def execute(self, request: ModbusRequest) -> ModbusResponse | bytes | ModbusIOException: # noqa: C901 """Start the producer to send the next request to consumer.write(Frame(request)).""" with self._transaction_lock: try: @@ -333,7 +335,9 @@ def execute(self, request: ModbusRequest): # noqa: C901 self.client.close() return exc - def _retry_transaction(self, retries, reason, packet, response_length, full=False): + def _retry_transaction( + self, retries: int, reason: str, request: ModbusRequest, response_length: int | None, full=False + ) -> tuple[bytes, str | Exception | None]: """Retry transaction.""" Log.debug("Retry on {} response - {}", reason, retries) Log.debug('Changing transaction state from "WAITING_FOR_REPLY" to "RETRYING"') @@ -350,9 +354,11 @@ def _retry_transaction(self, retries, reason, packet, response_length, full=Fals if response_length == in_waiting: result = self._recv(response_length, full) return result, None - return self._transact(packet, response_length, full=full) + return self._transact(request, response_length, full=full) - def _transact(self, request: ModbusRequest, response_length, full=False, broadcast=False): + def _transact( + self, request: ModbusRequest, response_length: int | None, full=False, broadcast=False + ) -> tuple[bytes, str | Exception | None]: """Do a Write and Read transaction. :param packet: packet to be sent @@ -368,16 +374,13 @@ def _transact(self, request: ModbusRequest, response_length, full=False, broadca packet = self.client.framer.buildPacket(request) Log.debug("SEND: {}", packet, ":hex") size = self._send(packet) - if ( - isinstance(size, bytes) - and self.client.state == ModbusTransactionState.RETRYING - ): + if self.client.state == ModbusTransactionState.RETRYING: Log.debug( "Changing transaction state from " '"RETRYING" to "PROCESSING REPLY"' ) self.client.state = ModbusTransactionState.PROCESSING_REPLY - return size, None + return b"", None if self.client.comm_params.handle_local_echo is True: if self._recv(size, full) != packet: return b"", "Wrong local echo" @@ -405,11 +408,11 @@ def _transact(self, request: ModbusRequest, response_length, full=False, broadca result = b"" return result, last_exception - def _send(self, packet: bytes, _retrying=False): + def _send(self, packet: bytes, _retrying=False) -> int: """Send.""" return self.client.framer.sendPacket(packet) - def _recv(self, expected_response_length, full) -> bytes: # noqa: C901 + def _recv(self, expected_response_length: int | None, full: bool) -> bytes: # noqa: C901 """Receive.""" total = None if not full: @@ -420,8 +423,10 @@ def _recv(self, expected_response_length, full) -> bytes: # noqa: C901 min_size = 4 elif isinstance(self.client.framer, ModbusAsciiFramer): min_size = 5 - else: + elif expected_response_length: min_size = expected_response_length + else: + min_size = 0 read_min = self.client.framer.recvPacket(min_size) if len(read_min) != min_size: @@ -463,6 +468,8 @@ def _recv(self, expected_response_length, full) -> bytes: # noqa: C901 expected_response_length -= min_size total = expected_response_length + min_size else: + if exception_length is None: + exception_length = 0 expected_response_length = exception_length - min_size total = expected_response_length + min_size else: @@ -470,6 +477,8 @@ def _recv(self, expected_response_length, full) -> bytes: # noqa: C901 else: read_min = b"" total = expected_response_length + if expected_response_length is None: + expected_response_length = 0 result = self.client.framer.recvPacket(expected_response_length) result = read_min + result actual = len(result)