Skip to content

Commit

Permalink
Add more typing to transaction and base client
Browse files Browse the repository at this point in the history
  • Loading branch information
jameshilliard committed Jul 18, 2024
1 parent 0e2c7e9 commit a1eb85b
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 25 deletions.
16 changes: 8 additions & 8 deletions pymodbus/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from abc import abstractmethod
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any, cast
from typing import Any, cast, Coroutine

from pymodbus.client.mixin import ModbusClientMixin
from pymodbus.client.modbusclientprotocol import ModbusClientProtocol
Expand All @@ -20,7 +20,7 @@
from pymodbus.utilities import ModbusTransactionState


class ModbusBaseClient(ModbusClientMixin[Awaitable[ModbusResponse]]):
class ModbusBaseClient(ModbusClientMixin[Coroutine[Any, Any, ModbusResponse | None]]):
"""**ModbusBaseClient**.
Fixed parameters:
Expand Down Expand Up @@ -141,7 +141,7 @@ def idle_time(self) -> float:
return 0
return self.last_frame_end + self.silent_interval

def execute(self, request: ModbusRequest):
def execute(self, request: ModbusRequest) -> Coroutine[Any, Any, ModbusResponse | None]:
"""Execute request and get response (call **sync/async**).
:param request: The request to process
Expand All @@ -155,7 +155,7 @@ def execute(self, request: ModbusRequest):
# ----------------------------------------------------------------------- #
# Merged client methods
# ----------------------------------------------------------------------- #
async def async_execute(self, request) -> ModbusResponse:
async def async_execute(self, request) -> ModbusResponse | None:
"""Execute requests asynchronously."""
request.transaction_id = self.ctx.transaction.getNextTID()
packet = self.ctx.framer.buildPacket(request)
Expand Down Expand Up @@ -183,9 +183,9 @@ async def async_execute(self, request) -> ModbusResponse:
f"ERROR: No response received after {self.retries} retries"
)

return resp # type: ignore[return-value]
return resp

def build_response(self, request: ModbusRequest):
def build_response(self, request: ModbusRequest) -> asyncio.Future[ModbusResponse]:
"""Return a deferred response for the current request."""
my_future: asyncio.Future = asyncio.Future()
request.fut = my_future
Expand Down Expand Up @@ -222,7 +222,7 @@ def __str__(self):
)


class ModbusBaseSyncClient(ModbusClientMixin[ModbusResponse]):
class ModbusBaseSyncClient(ModbusClientMixin[ModbusResponse | bytes | ModbusIOException]):
"""**ModbusBaseClient**.
Fixed parameters:
Expand Down Expand Up @@ -336,7 +336,7 @@ def idle_time(self) -> float:
return 0
return self.last_frame_end + self.silent_interval

def execute(self, request: ModbusRequest) -> ModbusResponse:
def execute(self, request: ModbusRequest) -> ModbusResponse | bytes | ModbusIOException:
"""Execute request and get response (call **sync/async**).
:param request: The request to process
Expand Down
43 changes: 26 additions & 17 deletions pymodbus/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import time
from contextlib import suppress
from threading import RLock
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Tuple

from pymodbus.exceptions import (
ConnectionException,
Expand All @@ -29,7 +29,7 @@
ModbusTlsFramer,
)
from pymodbus.logging import Log
from pymodbus.pdu import ModbusRequest
from pymodbus.pdu import ModbusRequest, ModbusResponse
from pymodbus.transport import CommType
from pymodbus.utilities import ModbusTransactionState, hexlify_packets

Expand Down Expand Up @@ -167,13 +167,13 @@ def _set_adu_size(self):
else:
self.base_adu_size = -1

def _calculate_response_length(self, expected_pdu_size):
def _calculate_response_length(self, expected_pdu_size: int) -> int | None:
"""Calculate response length."""
if self.base_adu_size == -1:
return None
return self.base_adu_size + expected_pdu_size

def _calculate_exception_length(self):
def _calculate_exception_length(self) -> int | None:
"""Return the length of the Modbus Exception Response according to the type of Framer."""
if isinstance(self.client.framer, (ModbusSocketFramer, ModbusTlsFramer)):
return self.base_adu_size + 2 # Fcode(1), ExceptionCode(1)
Expand All @@ -183,7 +183,9 @@ def _calculate_exception_length(self):
return self.base_adu_size + 2 # Fcode(1), ExceptionCode(1)
return None

def _validate_response(self, request: ModbusRequest, response, exp_resp_len, is_udp=False):
def _validate_response(
self, request: ModbusRequest, response: bytes | int, exp_resp_len: int | None, is_udp=False
) -> bool:
"""Validate Incoming response against request.
:param request: Request sent
Expand All @@ -208,7 +210,7 @@ def _validate_response(self, request: ModbusRequest, response, exp_resp_len, is_
return mbap.get("length") == exp_resp_len
return True

def execute(self, request: ModbusRequest): # noqa: C901
def execute(self, request: ModbusRequest) -> ModbusResponse | bytes | ModbusIOException: # noqa: C901
"""Start the producer to send the next request to consumer.write(Frame(request))."""
with self._transaction_lock:
try:
Expand Down Expand Up @@ -333,7 +335,9 @@ def execute(self, request: ModbusRequest): # noqa: C901
self.client.close()
return exc

def _retry_transaction(self, retries, reason, packet, response_length, full=False):
def _retry_transaction(
self, retries: int, reason: str, request: ModbusRequest, response_length: int | None, full=False
) -> Tuple[bytes, str | Exception | None]:
"""Retry transaction."""
Log.debug("Retry on {} response - {}", reason, retries)
Log.debug('Changing transaction state from "WAITING_FOR_REPLY" to "RETRYING"')
Expand All @@ -350,9 +354,11 @@ def _retry_transaction(self, retries, reason, packet, response_length, full=Fals
if response_length == in_waiting:
result = self._recv(response_length, full)
return result, None
return self._transact(packet, response_length, full=full)
return self._transact(request, response_length, full=full)

def _transact(self, request: ModbusRequest, response_length, full=False, broadcast=False):
def _transact(
self, request: ModbusRequest, response_length: int | None, full=False, broadcast=False
) -> Tuple[bytes, str | Exception | None]:
"""Do a Write and Read transaction.
:param packet: packet to be sent
Expand All @@ -368,16 +374,13 @@ def _transact(self, request: ModbusRequest, response_length, full=False, broadca
packet = self.client.framer.buildPacket(request)
Log.debug("SEND: {}", packet, ":hex")
size = self._send(packet)
if (
isinstance(size, bytes)
and self.client.state == ModbusTransactionState.RETRYING
):
if self.client.state == ModbusTransactionState.RETRYING:
Log.debug(
"Changing transaction state from "
'"RETRYING" to "PROCESSING REPLY"'
)
self.client.state = ModbusTransactionState.PROCESSING_REPLY
return size, None
return b"", None
if self.client.comm_params.handle_local_echo is True:
if self._recv(size, full) != packet:
return b"", "Wrong local echo"
Expand Down Expand Up @@ -405,11 +408,11 @@ def _transact(self, request: ModbusRequest, response_length, full=False, broadca
result = b""
return result, last_exception

def _send(self, packet: bytes, _retrying=False):
def _send(self, packet: bytes, _retrying=False) -> int:
"""Send."""
return self.client.framer.sendPacket(packet)

def _recv(self, expected_response_length, full) -> bytes: # noqa: C901
def _recv(self, expected_response_length: int | None, full: bool) -> bytes: # noqa: C901
"""Receive."""
total = None
if not full:
Expand All @@ -420,8 +423,10 @@ def _recv(self, expected_response_length, full) -> bytes: # noqa: C901
min_size = 4
elif isinstance(self.client.framer, ModbusAsciiFramer):
min_size = 5
else:
elif expected_response_length:
min_size = expected_response_length
else:
min_size = 0

read_min = self.client.framer.recvPacket(min_size)
if len(read_min) != min_size:
Expand Down Expand Up @@ -463,13 +468,17 @@ def _recv(self, expected_response_length, full) -> bytes: # noqa: C901
expected_response_length -= min_size
total = expected_response_length + min_size
else:
if exception_length is None:
exception_length = 0
expected_response_length = exception_length - min_size
total = expected_response_length + min_size
else:
total = expected_response_length
else:
read_min = b""
total = expected_response_length
if expected_response_length is None:
expected_response_length = 0
result = self.client.framer.recvPacket(expected_response_length)
result = read_min + result
actual = len(result)
Expand Down

0 comments on commit a1eb85b

Please sign in to comment.