Skip to content

Commit

Permalink
Add more typing to transaction and base client
Browse files Browse the repository at this point in the history
  • Loading branch information
jameshilliard committed Jul 19, 2024
1 parent 72e3399 commit 3ddd6f6
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 28 deletions.
18 changes: 9 additions & 9 deletions pymodbus/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import asyncio
import socket
from abc import abstractmethod
from collections.abc import Awaitable, Callable
from collections.abc import Callable, Coroutine
from dataclasses import dataclass
from typing import Any, cast
from typing import Any, Union, cast

from pymodbus.client.mixin import ModbusClientMixin
from pymodbus.client.modbusclientprotocol import ModbusClientProtocol
Expand All @@ -20,7 +20,7 @@
from pymodbus.utilities import ModbusTransactionState


class ModbusBaseClient(ModbusClientMixin[Awaitable[ModbusResponse]]):
class ModbusBaseClient(ModbusClientMixin[Coroutine[Any, Any, Union[ModbusResponse, None]]]):
"""**ModbusBaseClient**.
Fixed parameters:
Expand Down Expand Up @@ -144,7 +144,7 @@ def idle_time(self) -> float:
return 0
return self.last_frame_end + self.silent_interval

def execute(self, request: ModbusRequest):
def execute(self, request: ModbusRequest) -> Coroutine[Any, Any, ModbusResponse | None]:
"""Execute request and get response (call **sync/async**).
:param request: The request to process
Expand All @@ -158,7 +158,7 @@ def execute(self, request: ModbusRequest):
# ----------------------------------------------------------------------- #
# Merged client methods
# ----------------------------------------------------------------------- #
async def async_execute(self, request) -> ModbusResponse:
async def async_execute(self, request) -> ModbusResponse | None:
"""Execute requests asynchronously."""
request.transaction_id = self.ctx.transaction.getNextTID()
packet = self.ctx.framer.buildPacket(request)
Expand Down Expand Up @@ -186,9 +186,9 @@ async def async_execute(self, request) -> ModbusResponse:
f"ERROR: No response received after {self.retries} retries"
)

return resp # type: ignore[return-value]
return resp

def build_response(self, request: ModbusRequest):
def build_response(self, request: ModbusRequest) -> asyncio.Future[ModbusResponse]:
"""Return a deferred response for the current request."""
my_future: asyncio.Future = asyncio.Future()
request.fut = my_future
Expand Down Expand Up @@ -225,7 +225,7 @@ def __str__(self):
)


class ModbusBaseSyncClient(ModbusClientMixin[ModbusResponse]):
class ModbusBaseSyncClient(ModbusClientMixin[Union[ModbusResponse, bytes, ModbusIOException]]):
"""**ModbusBaseClient**.
Fixed parameters:
Expand Down Expand Up @@ -346,7 +346,7 @@ def idle_time(self) -> float:
return 0
return self.last_frame_end + self.silent_interval

def execute(self, request: ModbusRequest) -> ModbusResponse:
def execute(self, request: ModbusRequest) -> ModbusResponse | bytes | ModbusIOException:
"""Execute request and get response (call **sync/async**).
:param request: The request to process
Expand Down
2 changes: 1 addition & 1 deletion pymodbus/framer/old_framer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def sendPacket(self, message: bytes):
"""
return self.client.send(message)

def recvPacket(self, size: int) -> bytes:
def recvPacket(self, size: int | None) -> bytes:
"""Receive packet from the bus.
With specified len
Expand Down
37 changes: 19 additions & 18 deletions pymodbus/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
ModbusTlsFramer,
)
from pymodbus.logging import Log
from pymodbus.pdu import ModbusRequest
from pymodbus.pdu import ModbusRequest, ModbusResponse
from pymodbus.transport import CommType
from pymodbus.utilities import ModbusTransactionState, hexlify_packets

Expand Down Expand Up @@ -167,13 +167,13 @@ def _set_adu_size(self):
else:
self.base_adu_size = -1

def _calculate_response_length(self, expected_pdu_size):
def _calculate_response_length(self, expected_pdu_size: int) -> int | None:
"""Calculate response length."""
if self.base_adu_size == -1:
return None
return self.base_adu_size + expected_pdu_size

def _calculate_exception_length(self):
def _calculate_exception_length(self) -> int | None:
"""Return the length of the Modbus Exception Response according to the type of Framer."""
if isinstance(self.client.framer, (ModbusSocketFramer, ModbusTlsFramer)):
return self.base_adu_size + 2 # Fcode(1), ExceptionCode(1)
Expand All @@ -183,7 +183,9 @@ def _calculate_exception_length(self):
return self.base_adu_size + 2 # Fcode(1), ExceptionCode(1)
return None

def _validate_response(self, request: ModbusRequest, response, exp_resp_len, is_udp=False):
def _validate_response(
self, request: ModbusRequest, response: bytes | int, exp_resp_len: int | None, is_udp=False
) -> bool:
"""Validate Incoming response against request.
:param request: Request sent
Expand All @@ -208,7 +210,7 @@ def _validate_response(self, request: ModbusRequest, response, exp_resp_len, is_
return mbap.get("length") == exp_resp_len
return True

def execute(self, request: ModbusRequest): # noqa: C901
def execute(self, request: ModbusRequest) -> ModbusResponse | bytes | ModbusIOException: # noqa: C901
"""Start the producer to send the next request to consumer.write(Frame(request))."""
with self._transaction_lock:
try:
Expand Down Expand Up @@ -333,7 +335,9 @@ def execute(self, request: ModbusRequest): # noqa: C901
self.client.close()
return exc

def _retry_transaction(self, retries, reason, packet, response_length, full=False):
def _retry_transaction(
self, retries: int, reason: str, request: ModbusRequest, response_length: int | None, full=False
) -> tuple[bytes, str | Exception | None]:
"""Retry transaction."""
Log.debug("Retry on {} response - {}", reason, retries)
Log.debug('Changing transaction state from "WAITING_FOR_REPLY" to "RETRYING"')
Expand All @@ -350,9 +354,11 @@ def _retry_transaction(self, retries, reason, packet, response_length, full=Fals
if response_length == in_waiting:
result = self._recv(response_length, full)
return result, None
return self._transact(packet, response_length, full=full)
return self._transact(request, response_length, full=full)

def _transact(self, request: ModbusRequest, response_length, full=False, broadcast=False):
def _transact(
self, request: ModbusRequest, response_length: int | None, full=False, broadcast=False
) -> tuple[bytes, str | Exception | None]:
"""Do a Write and Read transaction.
:param packet: packet to be sent
Expand All @@ -368,16 +374,12 @@ def _transact(self, request: ModbusRequest, response_length, full=False, broadca
packet = self.client.framer.buildPacket(request)
Log.debug("SEND: {}", packet, ":hex")
size = self._send(packet)
if (
isinstance(size, bytes)
and self.client.state == ModbusTransactionState.RETRYING
):
if size and self.client.state == ModbusTransactionState.RETRYING:
Log.debug(
"Changing transaction state from "
'"RETRYING" to "PROCESSING REPLY"'
)
self.client.state = ModbusTransactionState.PROCESSING_REPLY
return size, None
if self.client.comm_params.handle_local_echo is True:
if self._recv(size, full) != packet:
return b"", "Wrong local echo"
Expand Down Expand Up @@ -405,23 +407,22 @@ def _transact(self, request: ModbusRequest, response_length, full=False, broadca
result = b""
return result, last_exception

def _send(self, packet: bytes, _retrying=False):
def _send(self, packet: bytes, _retrying=False) -> int:
"""Send."""
return self.client.framer.sendPacket(packet)

def _recv(self, expected_response_length, full) -> bytes: # noqa: C901
def _recv(self, expected_response_length: int | None, full: bool) -> bytes: # noqa: C901
"""Receive."""
total = None
if not full:
exception_length = self._calculate_exception_length()
min_size = expected_response_length
if isinstance(self.client.framer, ModbusSocketFramer):
min_size = 8
elif isinstance(self.client.framer, ModbusRtuFramer):
min_size = 4
elif isinstance(self.client.framer, ModbusAsciiFramer):
min_size = 5
else:
min_size = expected_response_length

read_min = self.client.framer.recvPacket(min_size)
if len(read_min) != min_size:
Expand Down Expand Up @@ -462,7 +463,7 @@ def _recv(self, expected_response_length, full) -> bytes: # noqa: C901
if expected_response_length is not None:
expected_response_length -= min_size
total = expected_response_length + min_size
else:
if func_code >= 0x80 and exception_length:
expected_response_length = exception_length - min_size
total = expected_response_length + min_size
else:
Expand Down

0 comments on commit 3ddd6f6

Please sign in to comment.