Skip to content

Commit

Permalink
Add background parameter to servers. (#2529)
Browse files Browse the repository at this point in the history
  • Loading branch information
janiversen authored Dec 28, 2024
1 parent 4eccc75 commit f3762bb
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 128 deletions.
5 changes: 4 additions & 1 deletion pymodbus/client/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Base for all clients."""
from __future__ import annotations

import asyncio
from abc import abstractmethod
from collections.abc import Awaitable, Callable

Expand Down Expand Up @@ -57,7 +58,9 @@ async def connect(self) -> bool:
self.ctx.comm_params.host,
self.ctx.comm_params.port,
)
return await self.ctx.connect()
rc = await self.ctx.connect()
await asyncio.sleep(0.1)
return rc

def register(self, custom_response_class: type[ModbusPDU]) -> None:
"""Register a custom response class with the decoder (call **sync**).
Expand Down
6 changes: 6 additions & 0 deletions pymodbus/pdu/pdu.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import struct
from abc import abstractmethod

from pymodbus.datastore import ModbusSlaveContext
from pymodbus.exceptions import NotImplementedException
from pymodbus.logging import Log

Expand Down Expand Up @@ -79,6 +80,11 @@ def encode(self) -> bytes:
def decode(self, data: bytes) -> None:
"""Decode data part of the message."""

async def update_datastore(self, context: ModbusSlaveContext) -> ModbusPDU:
"""Run request against a datastore."""
_ = context
return ExceptionResponse(0, 0)

@classmethod
def calculateRtuFrameSize(cls, data: bytes) -> int:
"""Calculate the size of a PDU."""
Expand Down
15 changes: 9 additions & 6 deletions pymodbus/server/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import asyncio
from collections.abc import Callable
from contextlib import suppress

from pymodbus.datastore import ModbusServerContext
from pymodbus.device import ModbusControlBlock, ModbusDeviceIdentification
Expand Down Expand Up @@ -67,25 +68,27 @@ async def shutdown(self):
self.serving.set_result(True)
self.close()

async def serve_forever(self):
async def serve_forever(self, *, background: bool = False):
"""Start endless loop."""
if self.transport:
raise RuntimeError(
"Can't call serve_forever on an already running server object"
)
await self.listen()
Log.info("Server listening.")
await self.serving
Log.info("Server graceful shutdown.")
if not background:
with suppress(asyncio.exceptions.CancelledError):
await self.serving
Log.info("Server graceful shutdown.")

def callback_connected(self) -> None:
"""Call when connection is succcesfull."""
raise RuntimeError("callback_new_connection should never be called")

def callback_disconnected(self, exc: Exception | None) -> None:
"""Call when connection is lost."""
Log.debug("callback_disconnected called: {}", exc)
raise RuntimeError("callback_disconnected should never be called")

def callback_data(self, data: bytes, addr: tuple | None = None) -> int:
"""Handle received data."""
Log.debug("callback_data called: {} addr={}", data, ":hex", addr)
return 0
raise RuntimeError("callback_data should never be called")
106 changes: 33 additions & 73 deletions pymodbus/server/requesthandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import asyncio
import traceback

from pymodbus.exceptions import NoSuchSlaveException
from pymodbus.exceptions import ModbusIOException, NoSuchSlaveException
from pymodbus.logging import Log
from pymodbus.pdu.pdu import ExceptionResponse
from pymodbus.transaction import TransactionManager
Expand All @@ -29,9 +29,6 @@ def __init__(self, owner):
self.server = owner
self.framer = self.server.framer(self.server.decoder)
self.running = False
self.handler_task = None # coroutine to be run on asyncio loop
self.databuffer = b''
self.loop = asyncio.get_running_loop()
super().__init__(
params,
self.framer,
Expand All @@ -44,8 +41,7 @@ def __init__(self, owner):

def callback_new_connection(self) -> ModbusProtocol:
"""Call when listener receive new connection request."""
Log.debug("callback_new_connection called")
return ServerRequestHandler(self)
raise RuntimeError("callback_new_connection should never be called")

def callback_connected(self) -> None:
"""Call when connection is succcesfull."""
Expand All @@ -54,27 +50,11 @@ def callback_connected(self) -> None:
if self.server.broadcast_enable:
if 0 not in slaves:
slaves.append(0)
try:
self.running = True

# schedule the connection handler on the event loop
self.handler_task = asyncio.create_task(self.handle())
self.handler_task.set_name("server connection handler")
except Exception as exc: # pylint: disable=broad-except
Log.error(
"Server callback_connected exception: {}; {}",
exc,
traceback.format_exc(),
)

def callback_disconnected(self, call_exc: Exception | None) -> None:
"""Call when connection is lost."""
super().callback_disconnected(call_exc)
try:
if self.handler_task:
self.handler_task.cancel()
if hasattr(self.server, "on_connection_lost"):
self.server.on_connection_lost()
if call_exc is None:
Log.debug(
"Handler for stream [{}] has been canceled", self.comm_params.comm_name
Expand All @@ -93,66 +73,46 @@ def callback_disconnected(self, call_exc: Exception | None) -> None:
traceback.format_exc(),
)

async def handle(self) -> None:
"""Coroutine which represents a single master <=> slave conversation.
Once the client connection is established, the data chunks will be
fed to this coroutine via the asyncio.Queue object which is fed by
the ServerRequestHandler class's callback Future.
This callback future gets data from either asyncio.BaseProtocol.data_received
or asyncio.DatagramProtocol.datagram_received.
def callback_data(self, data: bytes, addr: tuple | None = None) -> int:
"""Handle received data."""
try:
used_len = super().callback_data(data, addr)
except ModbusIOException:
response = ExceptionResponse(
40,
exception_code=ExceptionResponse.ILLEGAL_FUNCTION
)
self.server_send(response, 0)
return(len(data))
if self.last_pdu:
if self.is_server:
self.loop.call_soon(self.handle_later)
else:
self.response_future.set_result(True)
return used_len

This function will execute without blocking in the while-loop and
yield to the asyncio event loop when the frame is exhausted.
As a result, multiple clients can be interleaved without any
interference between them.
"""
while self.running:
try:
pdu, *addr, exc = await self.server_execute()
if exc:
pdu = ExceptionResponse(
40,
exception_code=ExceptionResponse.ILLEGAL_FUNCTION
)
self.server_send(pdu, 0)
continue
await self.server_async_execute(pdu, *addr)
except asyncio.CancelledError:
# catch and ignore cancellation errors
if self.running:
Log.debug(
"Handler for stream [{}] has been canceled", self.comm_params.comm_name
)
self.running = False
except Exception as exc: # pylint: disable=broad-except
# force TCP socket termination as framer
# should handle application layer errors
Log.error(
'Unknown exception "{}" on stream {} forcing disconnect',
exc,
self.comm_params.comm_name,
)
self.close()
self.callback_disconnected(exc)
def handle_later(self):
"""Change sync (async not allowed in call_soon) to async."""
asyncio.run_coroutine_threadsafe(self.handle_request(), self.loop)

async def server_async_execute(self, request, *addr):
async def handle_request(self):
"""Handle request."""
broadcast = False
if not self.last_pdu:
return
try:
if self.server.broadcast_enable and not request.dev_id:
if self.server.broadcast_enable and not self.last_pdu.dev_id:
broadcast = True
# if broadcasting then execute on all slave contexts,
# note response will be ignored
for dev_id in self.server.context.slaves():
response = await request.update_datastore(self.server.context[dev_id])
response = await self.last_pdu.update_datastore(self.server.context[dev_id])
else:
context = self.server.context[request.dev_id]
response = await request.update_datastore(context)
context = self.server.context[self.last_pdu.dev_id]
response = await self.last_pdu.update_datastore(context)

except NoSuchSlaveException:
Log.error("requested slave does not exist: {}", request.dev_id)
Log.error("requested slave does not exist: {}", self.last_pdu.dev_id)
if self.server.ignore_missing_slaves:
return # the client will simply timeout waiting for a response
response = ExceptionResponse(0x00, ExceptionResponse.GATEWAY_NO_RESPONSE)
Expand All @@ -165,9 +125,9 @@ async def server_async_execute(self, request, *addr):
response = ExceptionResponse(0x00, ExceptionResponse.SLAVE_FAILURE)
# no response when broadcasting
if not broadcast:
response.transaction_id = request.transaction_id
response.dev_id = request.dev_id
self.server_send(response, *addr)
response.transaction_id = self.last_pdu.transaction_id
response.dev_id = self.last_pdu.dev_id
self.server_send(response, self.last_addr)

def server_send(self, pdu, addr):
"""Send message."""
Expand Down
71 changes: 45 additions & 26 deletions pymodbus/server/startstop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import asyncio
import os
from contextlib import suppress

from pymodbus.datastore import ModbusServerContext
from pymodbus.pdu import ModbusPDU
Expand All @@ -27,13 +26,13 @@ async def StartAsyncTcpServer( # pylint: disable=invalid-name
:parameter context: Datastore object
:parameter custom_functions: optional list of custom PDU objects
:parameter kwargs: for parameter explanation see ModbusTcpServer
.. tip::
Only handles a single server !
Use ModbusTcpServer to allow multiple servers in one app.
"""
server = ModbusTcpServer(context, **kwargs)
if custom_functions:
for func in custom_functions:
server.decoder.register(func)
with suppress(asyncio.exceptions.CancelledError):
await server.serve_forever()
await ModbusTcpServer(context, custom_pdu=custom_functions, **kwargs).serve_forever()


def StartTcpServer( # pylint: disable=invalid-name
Expand All @@ -46,8 +45,13 @@ def StartTcpServer( # pylint: disable=invalid-name
:parameter context: Datastore object
:parameter custom_functions: optional list of custom PDU objects
:parameter kwargs: for parameter explanation see ModbusTcpServer
.. tip::
Only handles a single server !
Use ModbusTcpServer to allow multiple servers in one app.
"""
return asyncio.run(StartAsyncTcpServer(context, custom_functions=custom_functions, **kwargs))
asyncio.run(StartAsyncTcpServer(context, custom_functions=custom_functions, **kwargs))


async def StartAsyncTlsServer( # pylint: disable=invalid-name
Expand All @@ -60,13 +64,13 @@ async def StartAsyncTlsServer( # pylint: disable=invalid-name
:parameter context: Datastore object
:parameter custom_functions: optional list of custom PDU objects
:parameter kwargs: for parameter explanation see ModbusTlsServer
.. tip::
Only handles a single server !
Use ModbusTlsServer to allow multiple servers in one app.
"""
server = ModbusTlsServer(context, **kwargs)
if custom_functions:
for func in custom_functions:
server.decoder.register(func)
with suppress(asyncio.exceptions.CancelledError):
await server.serve_forever()
await ModbusTlsServer(context, custom_pdu=custom_functions, **kwargs).serve_forever()


def StartTlsServer( # pylint: disable=invalid-name
Expand All @@ -79,6 +83,11 @@ def StartTlsServer( # pylint: disable=invalid-name
:parameter context: Datastore object
:parameter custom_functions: optional list of custom PDU objects
:parameter kwargs: for parameter explanation see ModbusTlsServer
.. tip::
Only handles a single server !
Use ModbusTlsServer to allow multiple servers in one app.
"""
asyncio.run(StartAsyncTlsServer(context, custom_functions=custom_functions, **kwargs))

Expand All @@ -93,13 +102,13 @@ async def StartAsyncUdpServer( # pylint: disable=invalid-name
:parameter context: Datastore object
:parameter custom_functions: optional list of custom PDU objects
:parameter kwargs: for parameter explanation see ModbusUdpServer
.. tip::
Only handles a single server !
Use ModbusUdpServer to allow multiple servers in one app.
"""
server = ModbusUdpServer(context, **kwargs)
if custom_functions:
for func in custom_functions:
server.decoder.register(func)
with suppress(asyncio.exceptions.CancelledError):
await server.serve_forever()
await ModbusUdpServer(context, custom_pdu=custom_functions, **kwargs).serve_forever()


def StartUdpServer( # pylint: disable=invalid-name
Expand All @@ -112,6 +121,11 @@ def StartUdpServer( # pylint: disable=invalid-name
:parameter context: Datastore object
:parameter custom_functions: optional list of custom PDU objects
:parameter kwargs: for parameter explanation see ModbusUdpServer
.. tip::
Only handles a single server !
Use ModbusUdpServer to allow multiple servers in one app.
"""
asyncio.run(StartAsyncUdpServer(context, custom_functions=custom_functions, **kwargs))

Expand All @@ -126,13 +140,13 @@ async def StartAsyncSerialServer( # pylint: disable=invalid-name
:parameter context: Datastore object
:parameter custom_functions: optional list of custom PDU objects
:parameter kwargs: for parameter explanation see ModbusSerialServer
.. tip::
Only handles a single server !
Use ModbusSerialServer to allow multiple servers in one app.
"""
server = ModbusSerialServer(context, **kwargs)
if custom_functions:
for func in custom_functions:
server.decoder.register(func)
with suppress(asyncio.exceptions.CancelledError):
await server.serve_forever()
await ModbusSerialServer(context, custom_pdu=custom_functions, **kwargs).serve_forever()


def StartSerialServer( # pylint: disable=invalid-name
Expand All @@ -145,6 +159,11 @@ def StartSerialServer( # pylint: disable=invalid-name
:parameter context: Datastore object
:parameter custom_functions: optional list of custom PDU objects
:parameter kwargs: for parameter explanation see ModbusSerialServer
.. tip::
Only handles a single server !
Use ModbusSerialServer to allow multiple servers in one app.
"""
asyncio.run(StartAsyncSerialServer(context, custom_functions=custom_functions, **kwargs))

Expand Down
Loading

0 comments on commit f3762bb

Please sign in to comment.