Skip to content

Commit

Permalink
Improve types for REPL (#2007)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexrudd2 authored Feb 12, 2024
1 parent 9a53351 commit b34eb46
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 49 deletions.
4 changes: 2 additions & 2 deletions pymodbus/repl/client/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(self, name, signature, doc, slave=False):
self._params = signature.parameters
self.args = self.create_completion()
else:
self._params = ""
self._params = {}

if self.name.startswith("client.") and slave:
self.args.update(**DEFAULT_KWARGS)
Expand Down Expand Up @@ -232,7 +232,7 @@ class Result:
"""Represent result command."""

function_code: int | None = None
data: dict[int, Any] | None = None
data: dict[str, Any] = {}

def __init__(self, result):
"""Initialize.
Expand Down
20 changes: 2 additions & 18 deletions pymodbus/repl/client/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
| | \___ / Y ( <_> ) /_/ | | | \ ___/| |_> > |__
|____| / ____\____|__ /\____/\____ | /\ |____|_ /\___ > __/|____/
\/ \/ \/ \/ \/ \/|__|
v1.3.0 - {pymodbus_version}
v1.3.1 - {pymodbus_version}
----------------------------------------------------------------------------
"""

Expand All @@ -64,16 +64,6 @@ def bottom_toolbar():
)


class CaseInsenstiveChoice(click.Choice):
"""Do case Insensitive choice for click commands and options."""

def convert(self, value, param, ctx):
"""Convert args to uppercase for evaluation."""
if value is None:
return None
return super().convert(value.strip().upper(), param, ctx)


class NumericChoice(click.Choice):
"""Do numeric choice for click arguments and options."""

Expand All @@ -88,12 +78,6 @@ def convert(self, value, param, ctx):
if value in self.choices:
return self.typ(value)

if ctx is not None and ctx.token_normalize_func is not None:
value = ctx.token_normalize_func(value)
for choice in self.casted_choices: # pylint: disable=no-member
if ctx.token_normalize_func(choice) == value:
return choice

self.fail(
f"invalid choice: {value}. (choose from {', '.join(self.choices)})",
param,
Expand Down Expand Up @@ -348,7 +332,7 @@ def tcp(ctx, host, port, framer):
"PARITY_NONE, PARITY_EVEN, PARITY_ODD PARITY_MARK, "
'PARITY_SPACE. Default to "N"',
default="N",
type=CaseInsenstiveChoice(["N", "E", "O", "M", "S"]),
type=click.Choice(["N", "E", "O", "M", "S"], case_sensitive=False),
)
@click.option(
"--stopbits",
Expand Down
72 changes: 43 additions & 29 deletions pymodbus/repl/client/mclient.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""Modbus Clients to be used with REPL."""
# pylint: disable=missing-type-doc
import functools
from typing import TYPE_CHECKING

from pymodbus.client import ModbusSerialClient as _ModbusSerialClient
from pymodbus.client import ModbusTcpClient as _ModbusTcpClient
from pymodbus.client.base import ModbusBaseSyncClient as _ModbusBaseSyncClient
from pymodbus.diag_message import (
ChangeAsciiInputDelimiterRequest,
ClearCountersRequest,
ClearOverrunCountRequest,
DiagnosticStatusResponse,
ForceListenOnlyModeRequest,
GetClearModbusPlusRequest,
RestartCommunicationsOptionRequest,
Expand All @@ -24,14 +27,22 @@
ReturnSlaveNoResponseCountRequest,
)
from pymodbus.exceptions import ModbusIOException
from pymodbus.mei_message import ReadDeviceInformationRequest
from pymodbus.mei_message import (
ReadDeviceInformationRequest,
ReadDeviceInformationResponse,
)
from pymodbus.other_message import (
GetCommEventCounterRequest,
GetCommEventCounterResponse,
GetCommEventLogRequest,
GetCommEventLogResponse,
ReadExceptionStatusRequest,
ReadExceptionStatusResponse,
ReportSlaveIdRequest,
ReportSlaveIdResponse,
)
from pymodbus.pdu import ExceptionResponse, ModbusExceptions
from pymodbus.register_write_message import MaskWriteRegisterResponse


def make_response_dict(resp):
Expand All @@ -46,7 +57,7 @@ def make_response_dict(resp):
return resp_dict


def handle_brodcast(func):
def handle_broadcast(func):
"""Handle broadcast."""

@functools.wraps(func)
Expand All @@ -63,8 +74,11 @@ def _wrapper(*args, **kwargs):

return _wrapper


class ExtendedRequestSupport: # pylint: disable=(too-many-public-methods
if TYPE_CHECKING:
_Base = _ModbusBaseSyncClient
else:
_Base = object
class ExtendedRequestSupport(_Base): # pylint: disable=(too-many-public-methods
"""Extended request support."""

@staticmethod
Expand Down Expand Up @@ -98,7 +112,7 @@ def read_coils(self, address, count=1, slave=0, **kwargs):
:param kwargs:
:returns: List of register values
"""
resp = super().read_coils( # pylint: disable=no-member
resp = super().read_coils(
address, count, slave, **kwargs
)
if not resp.isError():
Expand All @@ -114,14 +128,14 @@ def read_discrete_inputs(self, address, count=1, slave=0, **kwargs):
:param kwargs:
:return: List of bits
"""
resp = super().read_discrete_inputs( # pylint: disable=no-member
resp = super().read_discrete_inputs(
address, count, slave, **kwargs
)
if not resp.isError():
return {"function_code": resp.function_code, "bits": resp.bits}
return ExtendedRequestSupport._process_exception(resp, slave=slave)

@handle_brodcast
@handle_broadcast
def write_coil(self, address, value, slave=0, **kwargs):
"""Write `value` to coil at `address`.
Expand All @@ -131,12 +145,12 @@ def write_coil(self, address, value, slave=0, **kwargs):
:param kwargs:
:return:
"""
resp = super().write_coil( # pylint: disable=no-member
resp = super().write_coil(
address, value, slave, **kwargs
)
return resp

@handle_brodcast
@handle_broadcast
def write_coils(self, address, values, slave=0, **kwargs):
"""Write `value` to coil at `address`.
Expand All @@ -146,12 +160,12 @@ def write_coils(self, address, values, slave=0, **kwargs):
:param kwargs:
:return:
"""
resp = super().write_coils( # pylint: disable=no-member
resp = super().write_coils(
address, values, slave, **kwargs
)
return resp

@handle_brodcast
@handle_broadcast
def write_register(self, address, value, slave=0, **kwargs):
"""Write `value` to register at `address`.
Expand All @@ -161,12 +175,12 @@ def write_register(self, address, value, slave=0, **kwargs):
:param kwargs:
:return:
"""
resp = super().write_register( # pylint: disable=no-member
resp = super().write_register(
address, value, slave, **kwargs
)
return resp

@handle_brodcast
@handle_broadcast
def write_registers(self, address, values, slave=0, **kwargs):
"""Write list of `values` to registers starting at `address`.
Expand All @@ -176,7 +190,7 @@ def write_registers(self, address, values, slave=0, **kwargs):
:param kwargs:
:return:
"""
resp = super().write_registers( # pylint: disable=no-member
resp = super().write_registers(
address, values, slave, **kwargs
)
return resp
Expand All @@ -190,7 +204,7 @@ def read_holding_registers(self, address, count=1, slave=0, **kwargs):
:param kwargs:
:return:
"""
resp = super().read_holding_registers( # pylint: disable=no-member
resp = super().read_holding_registers(
address, count, slave, **kwargs
)
if not resp.isError():
Expand All @@ -206,7 +220,7 @@ def read_input_registers(self, address, count=1, slave=0, **kwargs):
:param kwargs:
:return:
"""
resp = super().read_input_registers( # pylint: disable=no-member
resp = super().read_input_registers(
address, count, slave, **kwargs
)
if not resp.isError():
Expand Down Expand Up @@ -235,7 +249,7 @@ def readwrite_registers(
:param kwargs:
:return:
"""
resp = super().readwrite_registers( # pylint: disable=no-member
resp = super().readwrite_registers(
read_address=read_address,
read_count=read_count,
write_address=write_address,
Expand Down Expand Up @@ -264,7 +278,7 @@ def mask_write_register(
:param kwargs:
:return:
"""
resp = super().mask_write_register( # pylint: disable=no-member
resp: MaskWriteRegisterResponse = super().mask_write_register(
address=address, and_mask=and_mask, or_mask=or_mask, slave=slave, **kwargs
)
if not resp.isError():
Expand All @@ -285,7 +299,7 @@ def read_device_information(self, read_code=None, object_id=0x00, **kwargs):
:return:
"""
request = ReadDeviceInformationRequest(read_code, object_id, **kwargs)
resp = self.execute(request) # pylint: disable=no-member
resp: ReadDeviceInformationResponse = self.execute(request)
if not resp.isError():
return {
"function_code": resp.function_code,
Expand All @@ -306,7 +320,7 @@ def report_slave_id(self, slave=0, **kwargs):
:return:
"""
request = ReportSlaveIdRequest(slave, **kwargs)
resp = self.execute(request) # pylint: disable=no-member
resp: ReportSlaveIdResponse = self.execute(request)
if not resp.isError():
return {
"function_code": resp.function_code,
Expand All @@ -324,7 +338,7 @@ def read_exception_status(self, slave=0, **kwargs):
:return:
"""
request = ReadExceptionStatusRequest(slave, **kwargs)
resp = self.execute(request) # pylint: disable=no-member
resp: ReadExceptionStatusResponse = self.execute(request)
if not resp.isError():
return {"function_code": resp.function_code, "status": resp.status}
return ExtendedRequestSupport._process_exception(resp, slave=request.slave_id)
Expand All @@ -338,7 +352,7 @@ def get_com_event_counter(self, **kwargs):
:return:
"""
request = GetCommEventCounterRequest(**kwargs)
resp = self.execute(request) # pylint: disable=no-member
resp: GetCommEventCounterResponse = self.execute(request)
if not resp.isError():
return {
"function_code": resp.function_code,
Expand All @@ -357,7 +371,7 @@ def get_com_event_log(self, **kwargs):
:return:
"""
request = GetCommEventLogRequest(**kwargs)
resp = self.execute(request) # pylint: disable=no-member
resp: GetCommEventLogResponse = self.execute(request)
if not resp.isError():
return {
"function_code": resp.function_code,
Expand All @@ -370,7 +384,7 @@ def get_com_event_log(self, **kwargs):

def _execute_diagnostic_request(self, request):
"""Execute diagnostic request."""
resp = self.execute(request) # pylint: disable=no-member
resp: DiagnosticStatusResponse = self.execute(request)
if not resp.isError():
return {
"function code": resp.function_code,
Expand All @@ -379,7 +393,7 @@ def _execute_diagnostic_request(self, request):
}
return ExtendedRequestSupport._process_exception(resp, slave=request.slave_id)

def return_query_data(self, message=0, **kwargs):
def return_query_data(self, message=b"\x00", **kwargs):
"""Loop back data sent in response.
:param message: Message to be looped back
Expand Down Expand Up @@ -582,14 +596,14 @@ def get_stopbits(self):
:return: Current Stop bits
"""
return self.params.stopbits
return self.comm_params.stopbits

def set_stopbits(self, value):
"""Set stop bit.
:param value: Possible values (1, 1.5, 2)
"""
self.params.stopbits = float(value)
self.comm_params.stopbits = float(value)
if self.is_socket_open():
self.close()

Expand All @@ -615,14 +629,14 @@ def get_parity(self):
:return: Current parity setting
"""
return self.params.parity
return self.comm_params.parity

def set_parity(self, value):
"""Set parity Setter.
:param value: Possible values ("N", "E", "O", "M", "S")
"""
self.params.parity = value
self.comm_params.parity = value
if self.is_socket_open():
self.close()

Expand Down

0 comments on commit b34eb46

Please sign in to comment.