Skip to content

Commit

Permalink
Add more typing to transaction and base client
Browse files Browse the repository at this point in the history
  • Loading branch information
jameshilliard committed Jul 18, 2024
1 parent 0e2c7e9 commit cd6a441
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 24 deletions.
16 changes: 8 additions & 8 deletions pymodbus/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import asyncio
import socket
from abc import abstractmethod
from collections.abc import Awaitable, Callable
from collections.abc import Callable, Coroutine
from dataclasses import dataclass
from typing import Any, cast

Expand All @@ -20,7 +20,7 @@
from pymodbus.utilities import ModbusTransactionState


class ModbusBaseClient(ModbusClientMixin[Awaitable[ModbusResponse]]):
class ModbusBaseClient(ModbusClientMixin[Coroutine[Any, Any, ModbusResponse | None]]):
"""**ModbusBaseClient**.
Fixed parameters:
Expand Down Expand Up @@ -141,7 +141,7 @@ def idle_time(self) -> float:
return 0
return self.last_frame_end + self.silent_interval

def execute(self, request: ModbusRequest):
def execute(self, request: ModbusRequest) -> Coroutine[Any, Any, ModbusResponse | None]:
"""Execute request and get response (call **sync/async**).
:param request: The request to process
Expand All @@ -155,7 +155,7 @@ def execute(self, request: ModbusRequest):
# ----------------------------------------------------------------------- #
# Merged client methods
# ----------------------------------------------------------------------- #
async def async_execute(self, request) -> ModbusResponse:
async def async_execute(self, request) -> ModbusResponse | None:
"""Execute requests asynchronously."""
request.transaction_id = self.ctx.transaction.getNextTID()
packet = self.ctx.framer.buildPacket(request)
Expand Down Expand Up @@ -183,9 +183,9 @@ async def async_execute(self, request) -> ModbusResponse:
f"ERROR: No response received after {self.retries} retries"
)

return resp # type: ignore[return-value]
return resp

def build_response(self, request: ModbusRequest):
def build_response(self, request: ModbusRequest) -> asyncio.Future[ModbusResponse]:
"""Return a deferred response for the current request."""
my_future: asyncio.Future = asyncio.Future()
request.fut = my_future
Expand Down Expand Up @@ -222,7 +222,7 @@ def __str__(self):
)


class ModbusBaseSyncClient(ModbusClientMixin[ModbusResponse]):
class ModbusBaseSyncClient(ModbusClientMixin[ModbusResponse | bytes | ModbusIOException]):
"""**ModbusBaseClient**.
Fixed parameters:
Expand Down Expand Up @@ -336,7 +336,7 @@ def idle_time(self) -> float:
return 0
return self.last_frame_end + self.silent_interval

def execute(self, request: ModbusRequest) -> ModbusResponse:
def execute(self, request: ModbusRequest) -> ModbusResponse | bytes | ModbusIOException:
"""Execute request and get response (call **sync/async**).
:param request: The request to process
Expand Down
41 changes: 25 additions & 16 deletions pymodbus/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
ModbusTlsFramer,
)
from pymodbus.logging import Log
from pymodbus.pdu import ModbusRequest
from pymodbus.pdu import ModbusRequest, ModbusResponse
from pymodbus.transport import CommType
from pymodbus.utilities import ModbusTransactionState, hexlify_packets

Expand Down Expand Up @@ -167,13 +167,13 @@ def _set_adu_size(self):
else:
self.base_adu_size = -1

def _calculate_response_length(self, expected_pdu_size):
def _calculate_response_length(self, expected_pdu_size: int) -> int | None:
"""Calculate response length."""
if self.base_adu_size == -1:
return None
return self.base_adu_size + expected_pdu_size

def _calculate_exception_length(self):
def _calculate_exception_length(self) -> int | None:
"""Return the length of the Modbus Exception Response according to the type of Framer."""
if isinstance(self.client.framer, (ModbusSocketFramer, ModbusTlsFramer)):
return self.base_adu_size + 2 # Fcode(1), ExceptionCode(1)
Expand All @@ -183,7 +183,9 @@ def _calculate_exception_length(self):
return self.base_adu_size + 2 # Fcode(1), ExceptionCode(1)
return None

def _validate_response(self, request: ModbusRequest, response, exp_resp_len, is_udp=False):
def _validate_response(
self, request: ModbusRequest, response: bytes | int, exp_resp_len: int | None, is_udp=False
) -> bool:
"""Validate Incoming response against request.
:param request: Request sent
Expand All @@ -208,7 +210,7 @@ def _validate_response(self, request: ModbusRequest, response, exp_resp_len, is_
return mbap.get("length") == exp_resp_len
return True

def execute(self, request: ModbusRequest): # noqa: C901
def execute(self, request: ModbusRequest) -> ModbusResponse | bytes | ModbusIOException: # noqa: C901
"""Start the producer to send the next request to consumer.write(Frame(request))."""
with self._transaction_lock:
try:
Expand Down Expand Up @@ -333,7 +335,9 @@ def execute(self, request: ModbusRequest): # noqa: C901
self.client.close()
return exc

def _retry_transaction(self, retries, reason, packet, response_length, full=False):
def _retry_transaction(
self, retries: int, reason: str, request: ModbusRequest, response_length: int | None, full=False
) -> tuple[bytes, str | Exception | None]:
"""Retry transaction."""
Log.debug("Retry on {} response - {}", reason, retries)
Log.debug('Changing transaction state from "WAITING_FOR_REPLY" to "RETRYING"')
Expand All @@ -350,9 +354,11 @@ def _retry_transaction(self, retries, reason, packet, response_length, full=Fals
if response_length == in_waiting:
result = self._recv(response_length, full)
return result, None
return self._transact(packet, response_length, full=full)
return self._transact(request, response_length, full=full)

def _transact(self, request: ModbusRequest, response_length, full=False, broadcast=False):
def _transact(
self, request: ModbusRequest, response_length: int | None, full=False, broadcast=False
) -> tuple[bytes, str | Exception | None]:
"""Do a Write and Read transaction.
:param packet: packet to be sent
Expand All @@ -368,16 +374,13 @@ def _transact(self, request: ModbusRequest, response_length, full=False, broadca
packet = self.client.framer.buildPacket(request)
Log.debug("SEND: {}", packet, ":hex")
size = self._send(packet)
if (
isinstance(size, bytes)
and self.client.state == ModbusTransactionState.RETRYING
):
if self.client.state == ModbusTransactionState.RETRYING:
Log.debug(
"Changing transaction state from "
'"RETRYING" to "PROCESSING REPLY"'
)
self.client.state = ModbusTransactionState.PROCESSING_REPLY
return size, None
return b"", None
if self.client.comm_params.handle_local_echo is True:
if self._recv(size, full) != packet:
return b"", "Wrong local echo"
Expand Down Expand Up @@ -405,11 +408,11 @@ def _transact(self, request: ModbusRequest, response_length, full=False, broadca
result = b""
return result, last_exception

def _send(self, packet: bytes, _retrying=False):
def _send(self, packet: bytes, _retrying=False) -> int:
"""Send."""
return self.client.framer.sendPacket(packet)

def _recv(self, expected_response_length, full) -> bytes: # noqa: C901
def _recv(self, expected_response_length: int | None, full: bool) -> bytes: # noqa: C901
"""Receive."""
total = None
if not full:
Expand All @@ -420,8 +423,10 @@ def _recv(self, expected_response_length, full) -> bytes: # noqa: C901
min_size = 4
elif isinstance(self.client.framer, ModbusAsciiFramer):
min_size = 5
else:
elif expected_response_length:
min_size = expected_response_length
else:
min_size = 0

read_min = self.client.framer.recvPacket(min_size)
if len(read_min) != min_size:
Expand Down Expand Up @@ -463,13 +468,17 @@ def _recv(self, expected_response_length, full) -> bytes: # noqa: C901
expected_response_length -= min_size
total = expected_response_length + min_size
else:
if exception_length is None:
exception_length = 0
expected_response_length = exception_length - min_size
total = expected_response_length + min_size
else:
total = expected_response_length
else:
read_min = b""
total = expected_response_length
if expected_response_length is None:
expected_response_length = 0
result = self.client.framer.recvPacket(expected_response_length)
result = read_min + result
actual = len(result)
Expand Down

0 comments on commit cd6a441

Please sign in to comment.