Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor transaction handling to better separate async and sync code. #2232

Merged
merged 1 commit into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 20 additions & 24 deletions pymodbus/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import asyncio
import socket
from abc import abstractmethod
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any, cast
Expand All @@ -14,7 +15,7 @@
from pymodbus.framer import FRAMER_NAME_TO_CLASS, FramerType, ModbusFramer
from pymodbus.logging import Log
from pymodbus.pdu import ModbusRequest, ModbusResponse
from pymodbus.transaction import ModbusTransactionManager
from pymodbus.transaction import SyncModbusTransactionManager
from pymodbus.transport import CommParams
from pymodbus.utilities import ModbusTransactionState

Expand Down Expand Up @@ -53,7 +54,6 @@ def __init__(
framer: FramerType,
timeout: float = 3,
retries: int = 3,
retry_on_empty: bool = False,
broadcast_enable: bool = False,
reconnect_delay: float = 0.1,
reconnect_delay_max: float = 300,
Expand Down Expand Up @@ -81,8 +81,6 @@ def __init__(
stopbits=kwargs.get("stopbits", None),
handle_local_echo=kwargs.get("handle_local_echo", False),
),
retries,
retry_on_empty,
on_connect_callback,
)
self.no_resend_on_retry = no_resend_on_retry
Expand Down Expand Up @@ -143,7 +141,7 @@ def idle_time(self) -> float:
return 0
return self.last_frame_end + self.silent_interval

def execute(self, request: ModbusRequest | None = None):
def execute(self, request: ModbusRequest):
"""Execute request and get response (call **sync/async**).

:param request: The request to process
Expand All @@ -165,7 +163,7 @@ async def async_execute(self, request) -> ModbusResponse:
count = 0
while count <= self.retries:
async with self._lock:
req = self.build_response(request.transaction_id)
req = self.build_response(request)
if not count or not self.no_resend_on_retry:
self.ctx.framer.resetFrame()
self.ctx.send(packet)
Expand All @@ -187,25 +185,17 @@ async def async_execute(self, request) -> ModbusResponse:

return resp # type: ignore[return-value]

def build_response(self, tid):
def build_response(self, request: ModbusRequest):
"""Return a deferred response for the current request."""
my_future: asyncio.Future = asyncio.Future()
request.fut = my_future
if not self.ctx.transport:
if not my_future.done():
my_future.set_exception(ConnectionException("Client is not connected"))
else:
self.ctx.transaction.addTransaction(my_future, tid)
self.ctx.transaction.addTransaction(request)
return my_future

# ----------------------------------------------------------------------- #
# Internal methods
# ----------------------------------------------------------------------- #
def recv(self, size):
"""Receive data.

:meta private:
"""

# ----------------------------------------------------------------------- #
# The magic methods
# ----------------------------------------------------------------------- #
Expand Down Expand Up @@ -309,10 +299,10 @@ def __init__(
self.slaves: list[int] = []

# Common variables.
self.framer = FRAMER_NAME_TO_CLASS.get(
self.framer: ModbusFramer = FRAMER_NAME_TO_CLASS.get(
framer, cast(type[ModbusFramer], framer)
)(ClientDecoder(), self)
self.transaction = ModbusTransactionManager(
self.transaction = SyncModbusTransactionManager(
self, retries=retries, retry_on_empty=retry_on_empty, **kwargs
)
self.reconnect_delay_current = self.params.reconnect_delay or 0
Expand Down Expand Up @@ -346,7 +336,7 @@ def idle_time(self) -> float:
return 0
return self.last_frame_end + self.silent_interval

def execute(self, request: ModbusRequest | None = None) -> ModbusResponse:
def execute(self, request: ModbusRequest) -> ModbusResponse:
"""Execute request and get response (call **sync/async**).

:param request: The request to process
Expand All @@ -360,22 +350,28 @@ def execute(self, request: ModbusRequest | None = None) -> ModbusResponse:
# ----------------------------------------------------------------------- #
# Internal methods
# ----------------------------------------------------------------------- #
def send(self, request):
def _start_send(self):
"""Send request.

:meta private:
"""
if self.state != ModbusTransactionState.RETRYING:
Log.debug('New Transaction state "SENDING"')
self.state = ModbusTransactionState.SENDING
return request

def recv(self, size):
@abstractmethod
def send(self, request: bytes) -> int:
"""Send request.

:meta private:
"""

@abstractmethod
def recv(self, size: int | None) -> bytes:
"""Receive data.

:meta private:
"""
return size

@classmethod
def get_address_family(cls, address):
Expand Down
11 changes: 4 additions & 7 deletions pymodbus/client/modbusclientprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ def __init__(
self,
framer: FramerType,
params: CommParams,
retries: int,
retry_on_empty: bool,
on_connect_callback: Callable[[bool], None] | None = None,
) -> None:
"""Initialize a client instance."""
Expand All @@ -45,17 +43,16 @@ def __init__(
self.framer = FRAMER_NAME_TO_CLASS.get(
framer, cast(type[ModbusFramer], framer)
)(ClientDecoder(), self)
self.transaction = ModbusTransactionManager(
self, retries=retries, retry_on_empty=retry_on_empty
)
self.transaction = ModbusTransactionManager()

def _handle_response(self, reply, **_kwargs):
"""Handle the processed response and link to correct deferred."""
if reply is not None:
tid = reply.transaction_id
if handler := self.transaction.getTransaction(tid):
if not handler.done():
handler.set_result(reply)
reply.request = handler
if not handler.fut.done():
handler.fut.set_result(reply)
else:
Log.debug("Unrequested message: {}", reply, ":str")

Expand Down
16 changes: 8 additions & 8 deletions pymodbus/client/serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def connected(self):
"""Connect internal."""
return self.connect()

def connect(self):
def connect(self) -> bool:
"""Connect to the modbus serial server."""
if self.socket:
return True
Expand Down Expand Up @@ -227,25 +227,26 @@ def _in_waiting(self):
"""Return waiting bytes."""
return getattr(self.socket, "in_waiting") if hasattr(self.socket, "in_waiting") else getattr(self.socket, "inWaiting")()

def send(self, request):
def send(self, request: bytes) -> int:
"""Send data on the underlying socket.

If receive buffer still holds some data then flush it.

Sleep if last send finished less than 3.5 character times ago.
"""
super().send(request)
super()._start_send()
if not self.socket:
raise ConnectionException(str(self))
if request:
if waitingbytes := self._in_waiting():
result = self.socket.read(waitingbytes)
Log.warning("Cleanup recv buffer before send: {}", result, ":hex")
size = self.socket.write(request)
if (size := self.socket.write(request)) is None:
size = 0
return size
return 0

def _wait_for_data(self):
def _wait_for_data(self) -> int:
"""Wait for data."""
size = 0
more_data = False
Expand All @@ -264,9 +265,8 @@ def _wait_for_data(self):
time.sleep(self._recv_interval)
return size

def recv(self, size):
def recv(self, size: int | None) -> bytes:
"""Read data from the underlying descriptor."""
super().recv(size)
if not self.socket:
raise ConnectionException(str(self))
if size is None:
Expand All @@ -276,7 +276,7 @@ def recv(self, size):
result = self.socket.read(size)
return result

def is_socket_open(self):
def is_socket_open(self) -> bool:
"""Check if socket is open."""
if self.socket:
return self.socket.is_open
Expand Down
9 changes: 4 additions & 5 deletions pymodbus/client/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,16 +180,15 @@ def close(self):

def send(self, request):
"""Send data on the underlying socket."""
super().send(request)
super()._start_send()
if not self.socket:
raise ConnectionException(str(self))
if request:
return self.socket.send(request)
return 0

def recv(self, size):
def recv(self, size: int | None) -> bytes:
"""Read data from the underlying descriptor."""
super().recv(size)
if not self.socket:
raise ConnectionException(str(self))

Expand Down Expand Up @@ -241,7 +240,7 @@ def recv(self, size):

return b"".join(data)

def _handle_abrupt_socket_close(self, size, data, duration):
def _handle_abrupt_socket_close(self, size: int | None, data: list[bytes], duration: float) -> bytes:
"""Handle unexpected socket close by remote end.

Intended to be invoked after determining that the remote end
Expand Down Expand Up @@ -271,7 +270,7 @@ def _handle_abrupt_socket_close(self, size, data, duration):
msg += " without response from slave before it closed connection"
raise ConnectionException(msg)

def is_socket_open(self):
def is_socket_open(self) -> bool:
"""Check if socket is open."""
return self.socket is not None

Expand Down
9 changes: 5 additions & 4 deletions pymodbus/client/udp.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,12 @@ def close(self):
"""
self.socket = None

def send(self, request):
def send(self, request: bytes) -> int:
"""Send data on the underlying socket.

:meta private:
"""
super().send(request)
super()._start_send()
if not self.socket:
raise ConnectionException(str(self))
if request:
Expand All @@ -183,14 +183,15 @@ def send(self, request):
)
return 0

def recv(self, size):
def recv(self, size: int | None) -> bytes:
"""Read data from the underlying descriptor.

:meta private:
"""
super().recv(size)
if not self.socket:
raise ConnectionException(str(self))
if size is None:
size = 0
return self.socket.recvfrom(size)[0]

def is_socket_open(self):
Expand Down
26 changes: 17 additions & 9 deletions pymodbus/framer/old_framer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
from __future__ import annotations

import time
from typing import Any
from typing import TYPE_CHECKING, Any

from pymodbus.factory import ClientDecoder, ServerDecoder
from pymodbus.framer.base import FramerBase
from pymodbus.logging import Log
from pymodbus.pdu import ModbusRequest


if TYPE_CHECKING:
from pymodbus.client.base import ModbusBaseSyncClient

# Unit ID, Function Code
BYTE_ORDER = ">"
FRAME_HEADER = "BB"
Expand All @@ -29,24 +33,28 @@ class ModbusFramer:
def __init__(
self,
decoder: ClientDecoder | ServerDecoder,
client,
client: ModbusBaseSyncClient,
) -> None:
"""Initialize a new instance of the framer.

:param decoder: The decoder implementation to use
"""
self.decoder = decoder
self.client = client
self._header: dict[str, Any] = {
self._header: dict[str, Any]
self._reset_header()
self._buffer = b""
self.message_handler: FramerBase

def _reset_header(self) -> None:
self._header = {
"lrc": "0000",
"len": 0,
"uid": 0x00,
"tid": 0,
"pid": 0,
"crc": b"\x00\x00",
}
self._buffer = b""
self.message_handler: FramerBase

def _validate_slave_id(self, slaves: list, single: bool) -> bool:
"""Validate if the received data is valid for the client.
Expand All @@ -63,7 +71,7 @@ def _validate_slave_id(self, slaves: list, single: bool) -> bool:
return True
return self._header["uid"] in slaves

def sendPacket(self, message):
def sendPacket(self, message: bytes):
"""Send packets on the bus.

With 3.5char delay between frames
Expand All @@ -72,7 +80,7 @@ def sendPacket(self, message):
"""
return self.client.send(message)

def recvPacket(self, size):
def recvPacket(self, size: int) -> bytes:
"""Receive packet from the bus.

With specified len
Expand Down Expand Up @@ -117,7 +125,7 @@ def populateResult(self, result):
result.transaction_id = self._header.get("tid", 0)
result.protocol_id = self._header.get("pid", 0)

def processIncomingPacket(self, data, callback, slave, **kwargs):
def processIncomingPacket(self, data: bytes, callback, slave, **kwargs):
"""Process new packet pattern.

This takes in a new request packet, adds it to the current
Expand Down Expand Up @@ -150,7 +158,7 @@ def frameProcessIncomingPacket(
) -> None:
"""Process new packet pattern."""

def buildPacket(self, message) -> bytes:
def buildPacket(self, message: ModbusRequest) -> bytes:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be ModbusRequest | ModbusResponse

"""Create a ready to send modbus packet.

:param message: The populated request/response to send
Expand Down
Loading