Skip to content

Commit

Permalink
Refactor transaction handling to better separate async and sync code.
Browse files Browse the repository at this point in the history
  • Loading branch information
jameshilliard committed Jul 14, 2024
1 parent 5e146c0 commit db108b0
Show file tree
Hide file tree
Showing 15 changed files with 179 additions and 160 deletions.
44 changes: 20 additions & 24 deletions pymodbus/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import asyncio
import socket
from abc import abstractmethod
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any, cast
Expand All @@ -14,7 +15,7 @@
from pymodbus.framer import FRAMER_NAME_TO_CLASS, FramerType, ModbusFramer
from pymodbus.logging import Log
from pymodbus.pdu import ModbusRequest, ModbusResponse
from pymodbus.transaction import ModbusTransactionManager
from pymodbus.transaction import SyncModbusTransactionManager
from pymodbus.transport import CommParams
from pymodbus.utilities import ModbusTransactionState

Expand Down Expand Up @@ -53,7 +54,6 @@ def __init__(
framer: FramerType,
timeout: float = 3,
retries: int = 3,
retry_on_empty: bool = False,
broadcast_enable: bool = False,
reconnect_delay: float = 0.1,
reconnect_delay_max: float = 300,
Expand Down Expand Up @@ -81,8 +81,6 @@ def __init__(
stopbits=kwargs.get("stopbits", None),
handle_local_echo=kwargs.get("handle_local_echo", False),
),
retries,
retry_on_empty,
on_connect_callback,
)
self.no_resend_on_retry = no_resend_on_retry
Expand Down Expand Up @@ -143,7 +141,7 @@ def idle_time(self) -> float:
return 0
return self.last_frame_end + self.silent_interval

def execute(self, request: ModbusRequest | None = None):
def execute(self, request: ModbusRequest):
"""Execute request and get response (call **sync/async**).
:param request: The request to process
Expand All @@ -165,7 +163,7 @@ async def async_execute(self, request) -> ModbusResponse:
count = 0
while count <= self.retries:
async with self._lock:
req = self.build_response(request.transaction_id)
req = self.build_response(request)
if not count or not self.no_resend_on_retry:
self.ctx.framer.resetFrame()
self.ctx.send(packet)
Expand All @@ -187,25 +185,17 @@ async def async_execute(self, request) -> ModbusResponse:

return resp # type: ignore[return-value]

def build_response(self, tid):
def build_response(self, request: ModbusRequest):
"""Return a deferred response for the current request."""
my_future: asyncio.Future = asyncio.Future()
request.fut = my_future
if not self.ctx.transport:
if not my_future.done():
my_future.set_exception(ConnectionException("Client is not connected"))
else:
self.ctx.transaction.addTransaction(my_future, tid)
self.ctx.transaction.addTransaction(request)
return my_future

# ----------------------------------------------------------------------- #
# Internal methods
# ----------------------------------------------------------------------- #
def recv(self, size):
"""Receive data.
:meta private:
"""

# ----------------------------------------------------------------------- #
# The magic methods
# ----------------------------------------------------------------------- #
Expand Down Expand Up @@ -309,10 +299,10 @@ def __init__(
self.slaves: list[int] = []

# Common variables.
self.framer = FRAMER_NAME_TO_CLASS.get(
self.framer: ModbusFramer = FRAMER_NAME_TO_CLASS.get(
framer, cast(type[ModbusFramer], framer)
)(ClientDecoder(), self)
self.transaction = ModbusTransactionManager(
self.transaction = SyncModbusTransactionManager(
self, retries=retries, retry_on_empty=retry_on_empty, **kwargs
)
self.reconnect_delay_current = self.params.reconnect_delay or 0
Expand Down Expand Up @@ -346,7 +336,7 @@ def idle_time(self) -> float:
return 0
return self.last_frame_end + self.silent_interval

def execute(self, request: ModbusRequest | None = None) -> ModbusResponse:
def execute(self, request: ModbusRequest) -> ModbusResponse:
"""Execute request and get response (call **sync/async**).
:param request: The request to process
Expand All @@ -360,22 +350,28 @@ def execute(self, request: ModbusRequest | None = None) -> ModbusResponse:
# ----------------------------------------------------------------------- #
# Internal methods
# ----------------------------------------------------------------------- #
def send(self, request):
def _start_send(self):
"""Send request.
:meta private:
"""
if self.state != ModbusTransactionState.RETRYING:
Log.debug('New Transaction state "SENDING"')
self.state = ModbusTransactionState.SENDING
return request

def recv(self, size):
@abstractmethod
def send(self, request: bytes) -> int:
"""Send request.
:meta private:
"""

@abstractmethod
def recv(self, size: int | None) -> bytes:
"""Receive data.
:meta private:
"""
return size

@classmethod
def get_address_family(cls, address):
Expand Down
11 changes: 4 additions & 7 deletions pymodbus/client/modbusclientprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ def __init__(
self,
framer: FramerType,
params: CommParams,
retries: int,
retry_on_empty: bool,
on_connect_callback: Callable[[bool], None] | None = None,
) -> None:
"""Initialize a client instance."""
Expand All @@ -45,17 +43,16 @@ def __init__(
self.framer = FRAMER_NAME_TO_CLASS.get(
framer, cast(type[ModbusFramer], framer)
)(ClientDecoder(), self)
self.transaction = ModbusTransactionManager(
self, retries=retries, retry_on_empty=retry_on_empty
)
self.transaction = ModbusTransactionManager()

def _handle_response(self, reply, **_kwargs):
"""Handle the processed response and link to correct deferred."""
if reply is not None:
tid = reply.transaction_id
if handler := self.transaction.getTransaction(tid):
if not handler.done():
handler.set_result(reply)
reply.request = handler
if not handler.fut.done():
handler.fut.set_result(reply)
else:
Log.debug("Unrequested message: {}", reply, ":str")

Expand Down
16 changes: 8 additions & 8 deletions pymodbus/client/serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def connected(self):
"""Connect internal."""
return self.connect()

def connect(self):
def connect(self) -> bool:
"""Connect to the modbus serial server."""
if self.socket:
return True
Expand Down Expand Up @@ -227,25 +227,26 @@ def _in_waiting(self):
"""Return waiting bytes."""
return getattr(self.socket, "in_waiting") if hasattr(self.socket, "in_waiting") else getattr(self.socket, "inWaiting")()

def send(self, request):
def send(self, request: bytes) -> int:
"""Send data on the underlying socket.
If receive buffer still holds some data then flush it.
Sleep if last send finished less than 3.5 character times ago.
"""
super().send(request)
super()._start_send()
if not self.socket:
raise ConnectionException(str(self))
if request:
if waitingbytes := self._in_waiting():
result = self.socket.read(waitingbytes)
Log.warning("Cleanup recv buffer before send: {}", result, ":hex")
size = self.socket.write(request)
if (size := self.socket.write(request)) is None:
size = 0
return size
return 0

def _wait_for_data(self):
def _wait_for_data(self) -> int:
"""Wait for data."""
size = 0
more_data = False
Expand All @@ -264,9 +265,8 @@ def _wait_for_data(self):
time.sleep(self._recv_interval)
return size

def recv(self, size):
def recv(self, size: int | None) -> bytes:
"""Read data from the underlying descriptor."""
super().recv(size)
if not self.socket:
raise ConnectionException(str(self))
if size is None:
Expand All @@ -276,7 +276,7 @@ def recv(self, size):
result = self.socket.read(size)
return result

def is_socket_open(self):
def is_socket_open(self) -> bool:
"""Check if socket is open."""
if self.socket:
return self.socket.is_open
Expand Down
9 changes: 4 additions & 5 deletions pymodbus/client/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,16 +180,15 @@ def close(self):

def send(self, request):
"""Send data on the underlying socket."""
super().send(request)
super()._start_send()
if not self.socket:
raise ConnectionException(str(self))
if request:
return self.socket.send(request)
return 0

def recv(self, size):
def recv(self, size: int | None) -> bytes:
"""Read data from the underlying descriptor."""
super().recv(size)
if not self.socket:
raise ConnectionException(str(self))

Expand Down Expand Up @@ -241,7 +240,7 @@ def recv(self, size):

return b"".join(data)

def _handle_abrupt_socket_close(self, size, data, duration):
def _handle_abrupt_socket_close(self, size: int | None, data: list[bytes], duration: float) -> bytes:
"""Handle unexpected socket close by remote end.
Intended to be invoked after determining that the remote end
Expand Down Expand Up @@ -271,7 +270,7 @@ def _handle_abrupt_socket_close(self, size, data, duration):
msg += " without response from slave before it closed connection"
raise ConnectionException(msg)

def is_socket_open(self):
def is_socket_open(self) -> bool:
"""Check if socket is open."""
return self.socket is not None

Expand Down
9 changes: 5 additions & 4 deletions pymodbus/client/udp.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,12 @@ def close(self):
"""
self.socket = None

def send(self, request):
def send(self, request: bytes) -> int:
"""Send data on the underlying socket.
:meta private:
"""
super().send(request)
super()._start_send()
if not self.socket:
raise ConnectionException(str(self))
if request:
Expand All @@ -183,14 +183,15 @@ def send(self, request):
)
return 0

def recv(self, size):
def recv(self, size: int | None) -> bytes:
"""Read data from the underlying descriptor.
:meta private:
"""
super().recv(size)
if not self.socket:
raise ConnectionException(str(self))
if size is None:
size = 0
return self.socket.recvfrom(size)[0]

def is_socket_open(self):
Expand Down
26 changes: 17 additions & 9 deletions pymodbus/framer/old_framer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
from __future__ import annotations

import time
from typing import Any
from typing import TYPE_CHECKING, Any

from pymodbus.factory import ClientDecoder, ServerDecoder
from pymodbus.framer.base import FramerBase
from pymodbus.logging import Log
from pymodbus.pdu import ModbusRequest


if TYPE_CHECKING:
from pymodbus.client.base import ModbusBaseSyncClient

# Unit ID, Function Code
BYTE_ORDER = ">"
FRAME_HEADER = "BB"
Expand All @@ -29,24 +33,28 @@ class ModbusFramer:
def __init__(
self,
decoder: ClientDecoder | ServerDecoder,
client,
client: ModbusBaseSyncClient,
) -> None:
"""Initialize a new instance of the framer.
:param decoder: The decoder implementation to use
"""
self.decoder = decoder
self.client = client
self._header: dict[str, Any] = {
self._header: dict[str, Any]
self._reset_header()
self._buffer = b""
self.message_handler: FramerBase

def _reset_header(self) -> None:
self._header = {
"lrc": "0000",
"len": 0,
"uid": 0x00,
"tid": 0,
"pid": 0,
"crc": b"\x00\x00",
}
self._buffer = b""
self.message_handler: FramerBase

def _validate_slave_id(self, slaves: list, single: bool) -> bool:
"""Validate if the received data is valid for the client.
Expand All @@ -63,7 +71,7 @@ def _validate_slave_id(self, slaves: list, single: bool) -> bool:
return True
return self._header["uid"] in slaves

def sendPacket(self, message):
def sendPacket(self, message: bytes):
"""Send packets on the bus.
With 3.5char delay between frames
Expand All @@ -72,7 +80,7 @@ def sendPacket(self, message):
"""
return self.client.send(message)

def recvPacket(self, size):
def recvPacket(self, size: int) -> bytes:
"""Receive packet from the bus.
With specified len
Expand Down Expand Up @@ -117,7 +125,7 @@ def populateResult(self, result):
result.transaction_id = self._header.get("tid", 0)
result.protocol_id = self._header.get("pid", 0)

def processIncomingPacket(self, data, callback, slave, **kwargs):
def processIncomingPacket(self, data: bytes, callback, slave, **kwargs):
"""Process new packet pattern.
This takes in a new request packet, adds it to the current
Expand Down Expand Up @@ -150,7 +158,7 @@ def frameProcessIncomingPacket(
) -> None:
"""Process new packet pattern."""

def buildPacket(self, message) -> bytes:
def buildPacket(self, message: ModbusRequest) -> bytes:
"""Create a ready to send modbus packet.
:param message: The populated request/response to send
Expand Down
Loading

0 comments on commit db108b0

Please sign in to comment.