.."""
- pre = ""
- if self.pre:
- pre = f".{self.pre}"
- return f"{self.major}.{self.minor}.{self.micro}{pre}"
-
- def __str__(self) -> str:
- """Return a string representation of the object.
-
- :returns: A string representation of this object
- """
- return f"[{self.package}, version {self.short()}]"
-
-
-version = Version("pymodbus", 3, 2, 2, "")
diff --git a/requirements.txt b/requirements.txt
index 00885d88f..94748cf04 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,7 +12,6 @@
# Required packages.
# -------------------------------------------------------------------
# install:required
-setuptools<66.0.0
# -------------------------------------------------------------------
# optional packages.
@@ -33,10 +32,6 @@ click>=8.0.0
# install:serial
pyserial>=3.5
-# install:datastore
-redis>=2.10.6
-sqlalchemy>=1.1.15
-
# -------------------------------------------------------------------
# documentation, everything needed to generate documentation.
@@ -51,24 +46,19 @@ sphinx-rtd-theme==1.1.1
# development, everything needed to develop/test/check.
# -------------------------------------------------------------------
# install:development
-bandit==1.7.4
codespell==2.2.2
coverage==7.1.0
-flake8==6.0.0
-flake8-docstrings==1.7.0
-flake8-noqa==1.3.0
-flake8-comprehensions==3.10.1
-mypy==1.0.1
+mypy==1.3.0
pre-commit==3.1.1
pyflakes==3.0.1
pydocstyle==6.3.0
pycodestyle==2.10.0
-pylint==2.15.10
+pylint==2.17.2
pytest==7.2.1
pytest-asyncio==0.20.3
pytest-cov==4.0.0
pytest-timeout==2.1.0
pytest-xdist==3.1.0
-types-redis
+ruff==0.0.261
types-Pygments
types-pyserial
diff --git a/ruff.toml b/ruff.toml
new file mode 100644
index 000000000..75cd36afe
--- /dev/null
+++ b/ruff.toml
@@ -0,0 +1,52 @@
+target-version="py38"
+exclude = [
+ "pymodbus/transport/serial_asyncio",
+ "venv",
+ ".venv",
+ ".git",
+ "build",
+]
+ignore = [
+ "D202", # No blank lines allowed after function docstring (to work with black)
+ "D400", # docstrings ending in period
+ "E501", # line too long
+ "E731", # lambda expressions
+ "PT019", # Bug: https://github.com/m-burst/flake8-pytest-style/issues/202
+ "S101", # Use of `assert`
+ "S311", # PRNG for cryptography
+ "S104", # binding on all interfaces
+]
+line-length = 120
+select = [
+ "B007", # Loop control variable {name} not used within loop body
+ "B014", # Exception handler with duplicate exception
+ "C", # complexity
+ "D", # docstrings
+ "E", # pycodestyle errors
+ "F", # pyflakes
+ "I", # isort
+ "PGH", # pygrep-hooks
+ "PLC", # pylint
+ "PT", # flake8-pytest-style
+ "RUF", # ruff builtins
+ "S", # bandit
+ "SIM105", # flake8-simplify
+ "SIM117", #
+ "SIM118", #
+ "SIM201", #
+ "SIM212", #
+ "SIM300", #
+ "SIM401", #
+ "UP", # pyupgrade
+ "W", # pycodestyle warnings
+ # "TRY", # tryceratops
+ "TRY004", # Prefer TypeError exception for invalid type
+]
+[pydocstyle]
+convention = "pep257"
+[isort]
+lines-after-imports = 2
+known-local-folder = [
+ "common",
+ "contrib",
+]
diff --git a/setup.cfg b/setup.cfg
index 06579c5ab..bd5a1e976 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -65,20 +65,13 @@ include = pymodbus*
[pylint.master]
-# Specify a configuration file.
-#rcfile=
-
-# Python code to execute.
-#init-hook=
-
-# Files or directories to be skipped.
-#ignore=CVS
# Add files or directories matching the regex patterns to the ignore-list.
ignore-paths=
- examples/v2.5.3,
- pymodbus/client/serial_asyncio,
+ pymodbus/transport/serial_asyncio,
doc
+
+# Files or directories matching the regular expression patterns are skipped.
ignore-patterns=^\.#
# Pickle collected data for later comparisons.
@@ -97,12 +90,12 @@ load-plugins=
pylint.extensions.emptystring,
pylint.extensions.eq_without_hash,
pylint.extensions.for_any_all,
- pylint.extensions.mccabe,
pylint.extensions.overlapping_exceptions,
pylint.extensions.private_import,
pylint.extensions.set_membership,
pylint.extensions.typing,
# NOT WANTED:
+# pylint.extensions.mccabe, (replaced by ruff)
# pylint.extensions.broad_try_clause,
# pylint.extensions.consider_ternary_expression,
# pylint.extensions.empty_comment,
@@ -112,32 +105,10 @@ load-plugins=
# Use multiple processes to speed up Pylint, 0 will auto-detect.
jobs=0
-# pylint would attempt to guess common misconfiguration.
-suggestion-mode=yes
-
-# Allow loading of arbitrary C extensions.
-unsafe-load-any-extension=no
-
-# package or module names from where C extensions may be loaded.
-extension-pkg-allow-list=
-
# Minimum supported python version
-py-version = 3.8.0
-
-# Amount of potential inferred values with a single object.
-limit-inference-results=100
-
-# Specify a score threshold to be exceeded before program exits with error.
-fail-under=10.0
-
-# Return non-zero exit code if any of these messages/categories are detected.
-fail-on=
-
-
+py-version = 3.8
[pylint.messages_control]
-# Only show warnings with the listed confidence levels.
-# confidence=
# Enable/Disable the message/report/category/checker with the given id(s).
enable=all
@@ -149,42 +120,16 @@ disable=
suppressed-message, # NOT wanted
[pylint.reports]
+
# Set the output format.
output-format=text
-# Tells whether to display a full report or only the messages
-reports=no
-
-# Python expression which should return a note less than 10.
-evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))
-
-# Template used to display messages.
-#msg-template=
-
-# Activate the evaluation score.
-score=yes
-
[pylint.logging]
-# Logging modules to check the string format.
-logging-modules=logging
# The type of string formatting that logging methods do.
logging-format-style=new
-[pylint.miscellaneous]
-# List of note tags/regular expression to take in consideration.
-notes=FIXME,XXX,TODO
-#notes-rgx=
-
[pylint.similarities]
-# Minimum lines number of a similarity.
-min-similarity-lines=4
-
-# Ignore comments when computing similarities.
-ignore-comments=yes
-
-# Ignore docstrings when computing similarities.
-ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
@@ -192,180 +137,29 @@ ignore-imports=no
# Signatures are removed from the similarity computation
ignore-signatures=no
-
[pylint.variables]
-# Tells whether we should check for unused import in __init__ files.
-init-import=no
# A regular expression matching the name of dummy variables
-#dummy-variables-rgx=_$|dummy
dummy-variables-rgx=
-# List of additional names supposed to be defined in builtins.
-additional-builtins=
-
-# List of strings which can identify a callback function by name.
-#callbacks=cb_,_cb
-callbacks=
-
-# Tells whether unused global variables should be treated as a violation.
-allow-global-unused-variables=yes
-
-# List of names allowed to shadow builtins
-allowed-redefined-builtins=
-
-# Argument names that match this expression will be ignored.
-ignored-argument-names=_.*
-
-# List of qualified module names which can have objects that can redefine
-# builtins.
-#redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
-redefining-builtins-modules=past.builtins,future.builtins,builtins,io
-
[pylint.format]
-# Maximum number of characters on a single line.
-max-line-length=100
-
-# Regexp for a line that is allowed to be longer than the limit.
-ignore-long-lines=^\s*(# )??$
-
-# Allow the body of an if to be on the same line.
-single-line-if-stmt=no
-
-# Allow the body of a class to be on the same line as the declaration
-single-line-class-stmt=no
# Maximum number of lines in a module
max-module-lines=2000
-# String used as indentation unit.
-indent-string=' '
-indent-after-paren=4
-
-# Expected format of line ending.
-#expected-line-ending-format=
-
[pylint.basic]
+
# Good variable names which should always be accepted.
-#good-names=i,j,k,run,_
good-names=i,j,k,rr,fc,rq,fd,x,_
-#good-names-rgxs=
-# Bad variable names which should always be refused, separated by a comma
-bad-names=foo,bar,baz,toto,tutu,tata
-bad-names-rgxs=
-
-# Colon-delimited sets of names that determine each other's naming style when
-# the name regexes allow several styles.
-name-group=
-
-# Include a hint for the correct naming format with invalid-name
-include-naming-hint=no
-
-# Naming style matching correct function names.
-function-naming-style=snake_case
-function-rgx=[a-z_][a-z0-9_]{2,30}$
-
-# Naming style matching correct variable names.
-variable-naming-style=snake_case
-variable-rgx=[a-z_][a-z0-9_]{2,30}$
-
-# Naming style matching correct constant names.
-const-naming-style=UPPER_CASE
-const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$
-
-# Naming style matching correct attribute names.
-attr-naming-style=snake_case
+# Regular expression matching correct attribute names.
attr-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
-# Naming style matching correct argument names.
-argument-naming-style=snake_case
-argument-rgx=[a-z_][a-z0-9_]{2,30}$
-
-# Naming style matching correct class attribute names.
-class-attribute-naming-style=any
-class-attribute-rgx=([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$
-
-# Naming style matching correct class constant names.
-class-const-naming-style=UPPER_CASE
-#class-const-rgx=
-
-# Naming style matching correct inline iteration names.
-inlinevar-naming-style=any
-inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$
-
-# Naming style matching correct class names.
-class-naming-style=PascalCase
-class-rgx=[A-Z_][a-zA-Z0-9]+$
-
-# Naming style matching correct module names.
-module-naming-style=snake_case
-module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$
-
-# Naming style matching correct method names.
-method-naming-style=snake_case
+# Regular expression matching correct method names.
method-rgx=[a-z_][a-zA-Z0-9_]{2,}$
-# function or class names that do not require a docstring.
-no-docstring-rgx=__.*__
-
-# Minimum line length for functions/classes that require docstrings.
-docstring-min-length=-1
-
-# List of decorators that define properties, such as abc.abstractproperty.
-property-classes=abc.abstractproperty
-
-[pylint.typecheck]
-# Regex pattern to define which classes are considered mixins
-mixin-class-rgx=.*MixIn
-
-# List of module names for which member attributes should not be checked.
-ignored-modules=
-
-# List of class names for which member attributes should not be checked.
-ignored-classes=SQLObject, optparse.Values, thread._local, _thread._local
-
-# List of members which are set dynamically.
-generated-members=REQUEST,acl_users,aq_parent,argparse.Namespace
-
-# List of decorators that create context managers from functions.
-contextmanager-decorators=contextlib.contextmanager
-
-# Warn about missing members when the attribute can be None.
-ignore-none=yes
-
-# This flag controls whether pylint should warn about no-member
-ignore-on-opaque-inference=yes
-
-# Show a hint with possible names when a member name was not found.
-missing-member-hint=yes
-
-# The minimum edit distance a name should have.
-missing-member-hint-distance=1
-
-# The total number of similar names that should be taken in consideration.
-missing-member-max-choices=1
-
-[pylint.spelling]
-# Spelling dictionary name.
-spelling-dict=
-
-# List of comma separated words that should not be checked.
-spelling-ignore-words=
-
-# List of comma separated words that should be considered directives.
-spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,isort:skip,mypy:
-
-# A path to a file that contains private dictionary; one word per line.
-spelling-private-dict-file=
-
-# Tells whether to store unknown words.
-spelling-store-unknown-words=no
-
-# Limits count of emitted suggestions for spelling mistakes.
-max-spelling-suggestions=4
-
[pylint.design]
+
# Maximum number of arguments for function / method
max-args=10
@@ -381,125 +175,32 @@ max-branches=27
# Maximum number of statements in function / method body
max-statements=100
-# Maximum number of parents for a class (see R0901).
-max-parents=7
-
-# List of qualified class names to ignore when counting class parents.
-ignored-parents=
-
# Maximum number of attributes for a class (see R0902).
max-attributes=20
-# Minimum number of public methods for a class (see R0903).
-min-public-methods=2
-
# Maximum number of public methods for a class (see R0904).
max-public-methods=25
-# Maximum number of boolean expressions in an if statement (see R0916).
-max-bool-expr=5
-
-# Regular expressions of class ancestor names to ignore.
-exclude-too-few-public-methods=
-
[pylint.classes]
-# List of method names used to declare (i.e. assign) instance attributes.
-defining-attr-methods=__init__,__new__,setUp,__post_init__
-
-# List of valid names for the first argument in a class method.
-valid-classmethod-first-arg=cls
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
-# List of member names, which should be excluded from the protected access
-# warning.
-exclude-protected=_asdict,_fields,_replace,_source,_make
-
-# Warn about protected attribute access inside special methods
-check-protected-access-in-special-methods=no
-
[pylint.imports]
-# List of modules that can be imported at any level.
-#allow-any-import-level=
-allow-any-import-level=no
-
-# Allow wildcard imports from modules that define __all__.
-allow-wildcard-with-all=no
-
-# Analyse import fallback blocks.
-analyse-fallback-blocks=no
# Deprecated modules which should not be used, separated by a comma
deprecated-modules=regsub,TERMIOS,Bastion,rexec
-# Create a graph of every dependencies.
-import-graph=
-ext-import-graph=
-int-import-graph=
-
-# import to recognize a module as a standard compatibility libraries.
-known-standard-library=
-
-# Force import order to recognize a module as part of a third party library.
-known-third-party=enchant
-
-# Couples of modules and preferred modules, separated by a comma.
-preferred-modules=
-
[pylint.exceptions]
-# Exceptions that will emit a warning when being caught.
-overgeneral-exceptions=Exception
-[pylint.typing]
-# app / library does need runtime introspection of type annotations.
-runtime-typing = true
+# Exceptions that will emit a warning when being caught.
+overgeneral-exceptions=builtins.Exception
[pylint.deprecated_builtins]
+
# List of builtins function names that should not be used.
bad-functions=map,input
-[pylint.refactoring]
-# Maximum number of nested blocks for function / method body
-max-nested-blocks=5
-
-# Complete name of functions that never returns.
-never-returning-functions=sys.exit,argparse.parse_error
-
-[pylint.string]
-# inconsistent-quotes generates a warning.
-check-quote-consistency=no
-
-# implicit-str-concat should generate a warning.
-check-str-concat-over-line-jumps=no
-
-[pylint.code_style]
-# Max line length for which to sill emit suggestions.
-#max-line-length-suggestions=
-
-[flake8]
-exclude = pymodbus/client/serial_asyncio, venv,.venv,.git,build,examples/v2.5.3
-doctests = True
-max-line-length = 120
-# To work with Black
-# D202 No blank lines allowed after function docstring
-# E203: Whitespace before ':'
-# E501: line too long
-# W503: Line break occurred before a binary operator
-# W504 line break after binary operator
-ignore =
- D202,
- E203,
- E501,
- W503,
- W504,
-
- D211,
- D400,
- E731,
- W503
-noqa-require-code = True
-
[mypy]
strict_optional = False
@@ -550,6 +251,7 @@ upload_dir = build/sphinx/html
testpaths = test
addopts = -p no:warnings --dist loadgroup --numprocesses auto
asyncio_mode = auto
+timeout = 40
[coverage:run]
@@ -564,13 +266,3 @@ omit =
[codespell]
skip=./doc/_build,./doc/source/_static,venv,.venv,.git,htmlcov,CHANGELOG.rst,.mypy_cache
ignore-words-list = asend
-
-[isort]
-skip=doc/_build,venv,.venv,.git,pymodbus/client/serial_asyncio
-py_version=38
-profile=black
-line_length = 79
-lines_after_imports = 2
-known_local_folder =
- common
- contrib
diff --git a/setup.py b/setup.py
index 7cd2ed605..54fb40ea0 100644
--- a/setup.py
+++ b/setup.py
@@ -8,11 +8,11 @@
from setuptools import setup
-dependencies = {}
+dependencies: dict = {}
with open("requirements.txt") as reqs:
option = None
for line in reqs.read().split("\n"):
- if line == "":
+ if not line:
option = None
elif line.startswith("# install:"):
option = line.split(":")[1]
@@ -30,5 +30,7 @@
setup(
install_requires=install_req,
extras_require=dependencies,
- package_data={"pymodbus": ["py.typed"]},
+ package_data={
+ "pymodbus": ["py.typed", "server/simulator/setup.json", "server/simulator/web/**/*"],
+ },
)
diff --git a/test/conftest.py b/test/conftest.py
index f1514a852..2796e4e2b 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,6 +1,7 @@
"""Configure pytest."""
import functools
import platform
+from collections import deque
import pytest
@@ -80,27 +81,17 @@ class mockSocket: # pylint: disable=invalid-name
timeout = 2
- def __init__(self):
+ def __init__(self, copy_send=True):
"""Initialize."""
- self.data = None
+ self.packets = deque()
+ self.buffer = None
self.in_waiting = 0
+ self.copy_send = copy_send
- def mock_store(self, msg):
+ def mock_prepare_receive(self, msg):
"""Store message."""
- self.data = msg
- self.in_waiting = len(self.data)
-
- def mock_retrieve(self, size):
- """Get message."""
- if not self.data or not size:
- return b""
- if size >= len(self.data):
- retval = self.data
- else:
- retval = self.data[0:size]
- self.data = None
- self.in_waiting = 0
- return retval
+ self.packets.append(msg)
+ self.in_waiting += len(msg)
def close(self):
"""Close."""
@@ -108,25 +99,38 @@ def close(self):
def recv(self, size):
"""Receive."""
- return self.mock_retrieve(size)
+ if not self.packets or not size:
+ return b""
+ if not self.buffer:
+ self.buffer = self.packets.popleft()
+ if size >= len(self.buffer):
+ retval = self.buffer
+ self.buffer = None
+ else:
+ retval = self.buffer[0:size]
+ self.buffer = self.buffer[size]
+ self.in_waiting -= len(retval)
+ return retval
def read(self, size):
"""Read."""
- return self.mock_retrieve(size)
+ return self.recv(size)
+
+ def recvfrom(self, size):
+ """Receive from."""
+ return [self.recv(size)]
def send(self, msg):
"""Send."""
- self.mock_store(msg)
+ if not self.copy_send:
+ return len(msg)
+ self.packets.append(msg)
+ self.in_waiting += len(msg)
return len(msg)
- def recvfrom(self, size):
- """Receive from."""
- return [self.mock_retrieve(size)]
-
def sendto(self, msg, *_args):
"""Send to."""
- self.mock_store(msg)
- return len(msg)
+ return self.send(msg)
def setblocking(self, _flag):
"""Set blocking."""
diff --git a/test/test_all_messages.py b/test/test_all_messages.py
index 172ce9da0..0f832e724 100644
--- a/test/test_all_messages.py
+++ b/test/test_all_messages.py
@@ -1,6 +1,4 @@
"""Test all messages."""
-import unittest
-
from pymodbus.bit_read_message import (
ReadCoilsRequest,
ReadCoilsResponse,
@@ -35,69 +33,64 @@
# ---------------------------------------------------------------------------#
-class ModbusAllMessagesTests(unittest.TestCase):
+class TestAllMessages:
"""All messages tests."""
# -----------------------------------------------------------------------#
# Setup/TearDown
# -----------------------------------------------------------------------#
- def setUp(self):
- """Initialize the test environment and builds request/result encoding pairs."""
- arguments = {
- "read_address": 1,
- "read_count": 1,
- "write_address": 1,
- "write_registers": 1,
- }
- self.requests = [
- lambda unit: ReadCoilsRequest(1, 5, unit=unit),
- lambda unit: ReadDiscreteInputsRequest(1, 5, unit=unit),
- lambda unit: WriteSingleCoilRequest(1, 1, unit=unit),
- lambda unit: WriteMultipleCoilsRequest(1, [1], unit=unit),
- lambda unit: ReadHoldingRegistersRequest(1, 5, unit=unit),
- lambda unit: ReadInputRegistersRequest(1, 5, unit=unit),
- lambda unit: ReadWriteMultipleRegistersRequest(unit=unit, **arguments),
- lambda unit: WriteSingleRegisterRequest(1, 1, unit=unit),
- lambda unit: WriteMultipleRegistersRequest(1, [1], unit=unit),
- ]
- self.responses = [
- lambda unit: ReadCoilsResponse([1], unit=unit),
- lambda unit: ReadDiscreteInputsResponse([1], unit=unit),
- lambda unit: WriteSingleCoilResponse(1, 1, unit=unit),
- lambda unit: WriteMultipleCoilsResponse(1, [1], unit=unit),
- lambda unit: ReadHoldingRegistersResponse([1], unit=unit),
- lambda unit: ReadInputRegistersResponse([1], unit=unit),
- lambda unit: ReadWriteMultipleRegistersResponse([1], unit=unit),
- lambda unit: WriteSingleRegisterResponse(1, 1, unit=unit),
- lambda unit: WriteMultipleRegistersResponse(1, 1, unit=unit),
- ]
-
- def tearDown(self):
- """Clean up the test environment"""
+ requests = [
+ lambda slave: ReadCoilsRequest(1, 5, slave=slave),
+ lambda slave: ReadDiscreteInputsRequest(1, 5, slave=slave),
+ lambda slave: WriteSingleCoilRequest(1, 1, slave=slave),
+ lambda slave: WriteMultipleCoilsRequest(1, [1], slave=slave),
+ lambda slave: ReadHoldingRegistersRequest(1, 5, slave=slave),
+ lambda slave: ReadInputRegistersRequest(1, 5, slave=slave),
+ lambda slave: ReadWriteMultipleRegistersRequest(
+ slave=slave,
+ read_address=1,
+ read_count=1,
+ write_address=1,
+ write_registers=1,
+ ),
+ lambda slave: WriteSingleRegisterRequest(1, 1, slave=slave),
+ lambda slave: WriteMultipleRegistersRequest(1, [1], slave=slave),
+ ]
+ responses = [
+ lambda slave: ReadCoilsResponse([1], slave=slave),
+ lambda slave: ReadDiscreteInputsResponse([1], slave=slave),
+ lambda slave: WriteSingleCoilResponse(1, 1, slave=slave),
+ lambda slave: WriteMultipleCoilsResponse(1, [1], slave=slave),
+ lambda slave: ReadHoldingRegistersResponse([1], slave=slave),
+ lambda slave: ReadInputRegistersResponse([1], slave=slave),
+ lambda slave: ReadWriteMultipleRegistersResponse([1], slave=slave),
+ lambda slave: WriteSingleRegisterResponse(1, 1, slave=slave),
+ lambda slave: WriteMultipleRegistersResponse(1, 1, slave=slave),
+ ]
def test_initializing_slave_address_request(self):
- """Test that every request can initialize the unit id"""
- unit_id = 0x12
+ """Test that every request can initialize the slave id"""
+ slave_id = 0x12
for factory in self.requests:
- request = factory(unit_id)
- self.assertEqual(request.unit_id, unit_id)
+ request = factory(slave_id)
+ assert request.slave_id == slave_id
def test_initializing_slave_address_response(self):
- """Test that every response can initialize the unit id"""
- unit_id = 0x12
+ """Test that every response can initialize the slave id"""
+ slave_id = 0x12
for factory in self.responses:
- response = factory(unit_id)
- self.assertEqual(response.unit_id, unit_id)
+ response = factory(slave_id)
+ assert response.slave_id == slave_id
def test_forwarding_kwargs_to_pdu(self):
"""Test that the kwargs are forwarded to the pdu correctly"""
- request = ReadCoilsRequest(1, 5, unit=0x12, transaction=0x12, protocol=0x12)
- self.assertEqual(request.unit_id, 0x12)
- self.assertEqual(request.transaction_id, 0x12)
- self.assertEqual(request.protocol_id, 0x12)
+ request = ReadCoilsRequest(1, 5, slave=0x12, transaction=0x12, protocol=0x12)
+ assert request.slave_id == 0x12
+ assert request.transaction_id == 0x12
+ assert request.protocol_id == 0x12
request = ReadCoilsRequest(1, 5)
- self.assertEqual(request.unit_id, Defaults.Slave)
- self.assertEqual(request.transaction_id, Defaults.TransactionId)
- self.assertEqual(request.protocol_id, Defaults.ProtocolId)
+ assert request.slave_id == Defaults.Slave
+ assert request.transaction_id == Defaults.TransactionId
+ assert request.protocol_id == Defaults.ProtocolId
diff --git a/test/test_bit_read_messages.py b/test/test_bit_read_messages.py
index 5af92a65c..1e00723e3 100644
--- a/test/test_bit_read_messages.py
+++ b/test/test_bit_read_messages.py
@@ -7,7 +7,6 @@
* Read Coils
"""
import struct
-import unittest
from test.conftest import MockContext
from pymodbus.bit_read_message import (
@@ -26,7 +25,7 @@
# ---------------------------------------------------------------------------#
-class ModbusBitMessageTests(unittest.TestCase):
+class TestModbusBitMessage:
"""Modbus bit read message tests."""
# -----------------------------------------------------------------------#
@@ -43,19 +42,19 @@ def test_read_bit_base_class_methods(self):
"""Test basic bit message encoding/decoding"""
handle = ReadBitsRequestBase(1, 1)
msg = "ReadBitRequest(1,1)"
- self.assertEqual(msg, str(handle))
+ assert msg == str(handle)
handle = ReadBitsResponseBase([1, 1])
msg = "ReadBitsResponseBase(2)"
- self.assertEqual(msg, str(handle))
+ assert msg == str(handle)
def test_bit_read_base_request_encoding(self):
"""Test basic bit message encoding/decoding"""
for i in range(20):
handle = ReadBitsRequestBase(i, i)
result = struct.pack(">HH", i, i)
- self.assertEqual(handle.encode(), result)
+ assert handle.encode() == result
handle.decode(result)
- self.assertEqual((handle.address, handle.count), (i, i))
+ assert (handle.address, handle.count) == (i, i)
def test_bit_read_base_response_encoding(self):
"""Test basic bit message encoding/decoding"""
@@ -64,7 +63,7 @@ def test_bit_read_base_response_encoding(self):
handle = ReadBitsResponseBase(data)
result = handle.encode()
handle.decode(result)
- self.assertEqual(handle.bits[:i], data)
+ assert handle.bits[:i] == data
def test_bit_read_base_response_helper_methods(self):
"""Test the extra methods on a ReadBitsResponseBase"""
@@ -75,7 +74,7 @@ def test_bit_read_base_response_helper_methods(self):
for i in (1, 3, 5):
handle.resetBit(i)
for i in range(8):
- self.assertEqual(handle.getBit(i), False)
+ assert not handle.getBit(i)
def test_bit_read_base_requests(self):
"""Test bit read request encoding"""
@@ -84,7 +83,7 @@ def test_bit_read_base_requests(self):
ReadBitsResponseBase([1, 0, 1, 1, 0]): b"\x01\x0d",
}
for request, expected in iter(messages.items()):
- self.assertEqual(request.encode(), expected)
+ assert request.encode() == expected
def test_bit_read_message_execute_value_errors(self):
"""Test bit read request encoding"""
@@ -95,7 +94,7 @@ def test_bit_read_message_execute_value_errors(self):
]
for request in requests:
result = request.execute(context)
- self.assertEqual(ModbusExceptions.IllegalValue, result.exception_code)
+ assert ModbusExceptions.IllegalValue == result.exception_code
def test_bit_read_message_execute_address_errors(self):
"""Test bit read request encoding"""
@@ -106,7 +105,7 @@ def test_bit_read_message_execute_address_errors(self):
]
for request in requests:
result = request.execute(context)
- self.assertEqual(ModbusExceptions.IllegalAddress, result.exception_code)
+ assert ModbusExceptions.IllegalAddress == result.exception_code
def test_bit_read_message_execute_success(self):
"""Test bit read request encoding"""
@@ -118,7 +117,7 @@ def test_bit_read_message_execute_success(self):
]
for request in requests:
result = request.execute(context)
- self.assertEqual(result.bits, [True] * 5)
+ assert result.bits == [True] * 5
def test_bit_read_message_get_response_pdu(self):
"""Test bit read message get response pdu."""
@@ -132,4 +131,4 @@ def test_bit_read_message_get_response_pdu(self):
}
for request, expected in iter(requests.items()):
pdu_len = request.get_response_pdu_size()
- self.assertEqual(pdu_len, expected)
+ assert pdu_len == expected
diff --git a/test/test_bit_write_messages.py b/test/test_bit_write_messages.py
index d272f4f0d..6a42c42aa 100644
--- a/test/test_bit_write_messages.py
+++ b/test/test_bit_write_messages.py
@@ -6,7 +6,6 @@
* Read/Write Discretes
* Read Coils
"""
-import unittest
from test.conftest import FakeList, MockContext
from pymodbus.bit_write_message import (
@@ -23,7 +22,7 @@
# ---------------------------------------------------------------------------#
-class ModbusBitMessageTests(unittest.TestCase):
+class TestModbusBitMessage:
"""Modbus bit write message tests."""
# -----------------------------------------------------------------------#
@@ -47,56 +46,56 @@ def test_bit_write_base_requests(self):
WriteMultipleCoilsResponse(1, 1): b"\x00\x01\x00\x01",
}
for request, expected in iter(messages.items()):
- self.assertEqual(request.encode(), expected)
+ assert request.encode() == expected
def test_bit_write_message_get_response_pdu(self):
"""Test bit write message."""
requests = {WriteSingleCoilRequest(1, 0xABCD): 5}
for request, expected in iter(requests.items()):
pdu_len = request.get_response_pdu_size()
- self.assertEqual(pdu_len, expected)
+ assert pdu_len == expected
def test_write_multiple_coils_request(self):
"""Test write multiple coils."""
request = WriteMultipleCoilsRequest(1, [True] * 5)
request.decode(b"\x00\x01\x00\x05\x01\x1f")
- self.assertEqual(request.byte_count, 1)
- self.assertEqual(request.address, 1)
- self.assertEqual(request.values, [True] * 5)
- self.assertEqual(request.get_response_pdu_size(), 5)
+ assert request.byte_count == 1
+ assert request.address == 1
+ assert request.values == [True] * 5
+ assert request.get_response_pdu_size() == 5
request = WriteMultipleCoilsRequest(1, True)
request.decode(b"\x00\x01\x00\x01\x01\x01")
- self.assertEqual(request.byte_count, 1)
- self.assertEqual(request.address, 1)
- self.assertEqual(request.values, [True])
- self.assertEqual(request.get_response_pdu_size(), 5)
+ assert request.byte_count == 1
+ assert request.address == 1
+ assert request.values == [True]
+ assert request.get_response_pdu_size() == 5
def test_invalid_write_multiple_coils_request(self):
"""Test write invalid multiple coils."""
request = WriteMultipleCoilsRequest(1, None)
- self.assertEqual(request.values, [])
+ assert request.values == []
def test_write_single_coil_request_encode(self):
"""Test write single coil."""
request = WriteSingleCoilRequest(1, False)
- self.assertEqual(request.encode(), b"\x00\x01\x00\x00")
+ assert request.encode() == b"\x00\x01\x00\x00"
def test_write_single_coil_execute(self):
"""Test write single coil."""
context = MockContext(False, default=True)
request = WriteSingleCoilRequest(2, True)
result = request.execute(context)
- self.assertEqual(result.exception_code, ModbusExceptions.IllegalAddress)
+ assert result.exception_code == ModbusExceptions.IllegalAddress
context.valid = True
result = request.execute(context)
- self.assertEqual(result.encode(), b"\x00\x02\xff\x00")
+ assert result.encode() == b"\x00\x02\xff\x00"
context = MockContext(True, default=False)
request = WriteSingleCoilRequest(2, False)
result = request.execute(context)
- self.assertEqual(result.encode(), b"\x00\x02\x00\x00")
+ assert result.encode() == b"\x00\x02\x00\x00"
def test_write_multiple_coils_execute(self):
"""Test write multiple coils."""
@@ -104,31 +103,31 @@ def test_write_multiple_coils_execute(self):
# too many values
request = WriteMultipleCoilsRequest(2, FakeList(0x123456))
result = request.execute(context)
- self.assertEqual(result.exception_code, ModbusExceptions.IllegalValue)
+ assert result.exception_code == ModbusExceptions.IllegalValue
# bad byte count
request = WriteMultipleCoilsRequest(2, [0x00] * 4)
request.byte_count = 0x00
result = request.execute(context)
- self.assertEqual(result.exception_code, ModbusExceptions.IllegalValue)
+ assert result.exception_code == ModbusExceptions.IllegalValue
# does not validate
context.valid = False
request = WriteMultipleCoilsRequest(2, [0x00] * 4)
result = request.execute(context)
- self.assertEqual(result.exception_code, ModbusExceptions.IllegalAddress)
+ assert result.exception_code == ModbusExceptions.IllegalAddress
# validated request
context.valid = True
result = request.execute(context)
- self.assertEqual(result.encode(), b"\x00\x02\x00\x04")
+ assert result.encode() == b"\x00\x02\x00\x04"
def test_write_multiple_coils_response(self):
"""Test write multiple coils."""
response = WriteMultipleCoilsResponse()
response.decode(b"\x00\x80\x00\x08")
- self.assertEqual(response.address, 0x80)
- self.assertEqual(response.count, 0x08)
+ assert response.address == 0x80
+ assert response.count == 0x08
def test_serializing_to_string(self):
"""Test serializing to string."""
@@ -140,4 +139,4 @@ def test_serializing_to_string(self):
]
for request in requests:
result = str(request)
- self.assertTrue(result is not None and len(result))
+ assert result
diff --git a/test/test_client.py b/test/test_client.py
index 757975f96..8d36219fd 100755
--- a/test/test_client.py
+++ b/test/test_client.py
@@ -16,7 +16,7 @@
from pymodbus.client.base import ModbusBaseClient
from pymodbus.client.mixin import ModbusClientMixin
from pymodbus.constants import Defaults
-from pymodbus.exceptions import ConnectionException, NotImplementedException
+from pymodbus.exceptions import ConnectionException
from pymodbus.framer.ascii_framer import ModbusAsciiFramer
from pymodbus.framer.rtu_framer import ModbusRtuFramer
from pymodbus.framer.socket_framer import ModbusSocketFramer
@@ -38,7 +38,7 @@
],
)
@pytest.mark.parametrize(
- "method, arg, pdu_request",
+ ("method", "arg", "pdu_request"),
[
("read_coils", 1, pdu_bit_read.ReadCoilsRequest),
("read_discrete_inputs", 1, pdu_bit_read.ReadDiscreteInputsRequest),
@@ -206,7 +206,7 @@ def fake_execute(_self, request):
],
)
@pytest.mark.parametrize(
- "type_args, clientclass",
+ ("type_args", "clientclass"),
[
# TBD ("serial", lib_client.AsyncModbusSerialClient),
# TBD ("serial", lib_client.ModbusSerialClient),
@@ -239,9 +239,6 @@ async def test_client_instanciate(
to_test = dict(arg_list["fix"]["opt_args"], **cur_args["opt_args"])
to_test["host"] = cur_args["defaults"]["host"]
- for arg, arg_test in to_test.items():
- assert getattr(client.params, arg) == arg_test
-
# Test information methods
client.last_frame_end = 2
client.silent_interval = 2
@@ -249,32 +246,24 @@ async def test_client_instanciate(
client.last_frame_end = None
assert not client.idle_time()
- initial_delay = client.delay_ms
- assert initial_delay > 0
- client.delay_ms *= 2
-
- assert client.delay_ms > initial_delay
- client.reset_delay()
- assert client.delay_ms == initial_delay
-
rc1 = client._get_address_family("127.0.0.1") # pylint: disable=protected-access
- assert socket.AF_INET == rc1
+ assert rc1 == socket.AF_INET
rc2 = client._get_address_family("::1") # pylint: disable=protected-access
- assert socket.AF_INET6 == rc2
+ assert rc2 == socket.AF_INET6
# a successful execute
client.connect = lambda: True
- client._connected = True # pylint: disable=protected-access
+ client.transport = lambda: None
client.transaction = mock.Mock(**{"execute.return_value": True})
# a unsuccessful connect
client.connect = lambda: False
- client._connected = False # pylint: disable=protected-access
+ client.transport = None
with pytest.raises(ConnectionException):
client.execute()
-def test_client_modbusbaseclient():
+async def test_client_modbusbaseclient():
"""Test modbus base client class."""
client = ModbusBaseClient(framer=ModbusAsciiFramer)
client.register(pdu_bit_read.ReadCoilsResponse)
@@ -282,19 +271,11 @@ def test_client_modbusbaseclient():
assert client.send(buffer) == buffer
assert client.recv(10) == 10
- with pytest.raises(NotImplementedException):
- client.connect()
- with pytest.raises(NotImplementedException):
- client.is_socket_open()
- with pytest.raises(NotImplementedException):
- client.close()
-
with mock.patch(
"pymodbus.client.base.ModbusBaseClient.connect"
) as p_connect, mock.patch(
"pymodbus.client.base.ModbusBaseClient.close"
) as p_close:
-
p_connect.return_value = True
p_close.return_value = True
with ModbusBaseClient(framer=ModbusAsciiFramer) as b_client:
@@ -302,42 +283,30 @@ def test_client_modbusbaseclient():
p_connect.return_value = False
-async def test_client_made_connection():
+async def test_client_connection_made():
"""Test protocol made connection."""
client = lib_client.AsyncModbusTcpClient("127.0.0.1")
assert not client.connected
- client.client_made_connection(mock.sentinel.PROTOCOL)
+ client.connection_made(mock.sentinel.PROTOCOL)
assert client.connected
- client.client_made_connection(mock.sentinel.PROTOCOL_UNEXPECTED)
+ client.connection_made(mock.sentinel.PROTOCOL_UNEXPECTED)
assert client.connected
-async def test_client_lost_connection():
+async def test_client_connection_lost():
"""Test protocol lost connection."""
client = lib_client.AsyncModbusTcpClient("127.0.0.1")
assert not client.connected
# fake client is connected and *then* looses connection:
- client.connected = True
client.params.host = mock.sentinel.HOST
client.params.port = mock.sentinel.PORT
- with mock.patch(
- "pymodbus.client.tcp.AsyncModbusTcpClient._launch_reconnect"
- ) as mock_reconnect:
- mock_reconnect.return_value = mock.sentinel.RECONNECT_GENERATOR
-
- client.client_lost_connection(mock.sentinel.PROTOCOL_UNEXPECTED)
+ client.connection_lost(mock.sentinel.PROTOCOL_UNEXPECTED)
assert not client.connected
-
- client.connected = True
- with mock.patch(
- "pymodbus.client.tcp.AsyncModbusTcpClient._launch_reconnect"
- ) as mock_reconnect:
- mock_reconnect.return_value = mock.sentinel.RECONNECT_GENERATOR
-
- client.client_lost_connection(mock.sentinel.PROTOCOL)
+ client.connection_lost(mock.sentinel.PROTOCOL)
assert not client.connected
+ client.close()
async def test_client_base_async():
@@ -347,60 +316,26 @@ async def test_client_base_async():
) as p_connect, mock.patch(
"pymodbus.client.base.ModbusBaseClient.close"
) as p_close:
-
- loop = asyncio.get_event_loop()
- p_connect.return_value = loop.create_future()
+ asyncio.get_event_loop()
+ p_connect.return_value = asyncio.Future()
p_connect.return_value.set_result(True)
- p_close.return_value = loop.create_future()
+ p_close.return_value = asyncio.Future()
p_close.return_value.set_result(True)
async with ModbusBaseClient(framer=ModbusAsciiFramer) as client:
str(client)
- p_connect.return_value = loop.create_future()
+ p_connect.return_value = asyncio.Future()
p_connect.return_value.set_result(False)
- p_close.return_value = loop.create_future()
+ p_close.return_value = asyncio.Future()
p_close.return_value.set_result(False)
-@pytest.mark.skip
-async def test_client_protocol():
- """Test base modbus async client."""
- base = ModbusBaseClient(framer=ModbusSocketFramer)
- assert base.transport is None
- assert not base.async_connected
-
- base.connection_made(mock.sentinel.TRANSPORT)
- assert base.transport is mock.sentinel.TRANSPORT
- base.client_made_connection.assert_called_once_with( # pylint: disable=no-member
- base
- )
- assert not base.client_lost_connection.call_count # pylint: disable=no-member
-
- base.connection_lost(mock.sentinel.REASON)
- assert base.transport is None
- assert not base.client_made_connection.call_count # pylint: disable=no-member
- base.client_lost_connection.assert_called_once_with( # pylint: disable=no-member
- base
- )
- base.raise_future = mock.MagicMock()
- request = mock.MagicMock()
- base.transaction.addTransaction(request, 1)
- base.connection_lost(mock.sentinel.REASON)
- base.raise_future.assert_called_once()
- call_args = base.raise_future.call_args.args
- assert call_args[0] == request
- assert isinstance(call_args[1], ConnectionException)
- base.transport = mock.MagicMock()
- base.transport = None
- await base.async_close()
-
-
async def test_client_protocol_receiver():
"""Test the client protocol data received"""
base = ModbusBaseClient(framer=ModbusSocketFramer)
transport = mock.MagicMock()
base.connection_made(transport)
assert base.transport == transport
- assert base.async_connected
+ assert base.transport
data = b"\x00\x00\x12\x34\x00\x06\xff\x01\x01\x02\x00\x04"
# setup existing request
@@ -410,7 +345,7 @@ async def test_client_protocol_receiver():
result = response.result()
assert isinstance(result, pdu_bit_read.ReadCoilsResponse)
- base._connected = False # pylint: disable=protected-access
+ base.transport = None
with pytest.raises(ConnectionException):
await base._build_response(0x00) # pylint: disable=protected-access
@@ -423,7 +358,7 @@ async def test_client_protocol_response():
assert isinstance(excp, ConnectionException)
assert not list(base.transaction)
- base._connected = True # pylint: disable=protected-access
+ base.transport = lambda: None
base._build_response(0x00) # pylint: disable=protected-access
assert len(list(base.transaction)) == 1
@@ -443,13 +378,10 @@ async def test_client_protocol_handler():
assert result == reply
+@pytest.mark.skip()
async def test_client_protocol_execute():
"""Test the client protocol execute method"""
base = ModbusBaseClient(host="127.0.0.1", framer=ModbusSocketFramer)
- base.create_future = mock.MagicMock()
- fut = asyncio.Future()
- fut.set_result(fut)
- base.create_future.return_value = fut
transport = mock.MagicMock()
base.connection_made(transport)
base.transport.write = mock.Mock()
@@ -478,18 +410,23 @@ def test_client_udp_connect():
"""Test the Udp client connection method"""
with mock.patch.object(socket, "socket") as mock_method:
- class DummySocket: # pylint: disable=too-few-public-methods
+ class DummySocket:
"""Dummy socket."""
+ fileno = 1
+
def settimeout(self, *a, **kwa):
"""Set timeout."""
+ def setblocking(self, _flag):
+ """Set blocking"""
+
mock_method.return_value = DummySocket()
client = lib_client.ModbusUdpClient("127.0.0.1")
assert client.connect()
with mock.patch.object(socket, "socket") as mock_method:
- mock_method.side_effect = socket.error()
+ mock_method.side_effect = OSError()
client = lib_client.ModbusUdpClient("127.0.0.1")
assert not client.connect()
@@ -504,11 +441,29 @@ def test_client_tcp_connect():
assert client.connect()
with mock.patch.object(socket, "create_connection") as mock_method:
- mock_method.side_effect = socket.error()
+ mock_method.side_effect = OSError()
client = lib_client.ModbusTcpClient("127.0.0.1")
assert not client.connect()
+def test_client_tcp_reuse():
+ """Test the tcp client connection method"""
+ with mock.patch.object(socket, "create_connection") as mock_method:
+ _socket = mock.MagicMock()
+ mock_method.return_value = _socket
+ client = lib_client.ModbusTcpClient("127.0.0.1")
+ _socket.getsockname.return_value = ("dmmy", 1234)
+ assert client.connect()
+ client.close()
+ with mock.patch.object(socket, "create_connection") as mock_method:
+ _socket = mock.MagicMock()
+ mock_method.return_value = _socket
+ client = lib_client.ModbusTcpClient("127.0.0.1")
+ _socket.getsockname.return_value = ("dmmy", 1234)
+ assert client.connect()
+ client.close()
+
+
def test_client_tls_connect():
"""Test the tls client connection method"""
with mock.patch.object(ssl.SSLSocket, "connect") as mock_method:
@@ -516,13 +471,13 @@ def test_client_tls_connect():
assert client.connect()
with mock.patch.object(socket, "create_connection") as mock_method:
- mock_method.side_effect = socket.error()
+ mock_method.side_effect = OSError()
client = lib_client.ModbusTlsClient("127.0.0.1")
assert not client.connect()
@pytest.mark.parametrize(
- "datatype,value,registers",
+ ("datatype", "value", "registers"),
[
(ModbusClientMixin.DATATYPE.STRING, "abcd", [0x6162, 0x6364]),
(ModbusClientMixin.DATATYPE.STRING, "a", [0x6100]),
diff --git a/test/test_client_faulty_response.py b/test/test_client_faulty_response.py
new file mode 100644
index 000000000..de15d4633
--- /dev/null
+++ b/test/test_client_faulty_response.py
@@ -0,0 +1,40 @@
+"""Test server working as slave on a multidrop RS485 line."""
+from unittest import mock
+
+import pytest
+
+from pymodbus.exceptions import ModbusIOException
+from pymodbus.factory import ClientDecoder
+from pymodbus.framer import ModbusSocketFramer
+
+
+class TestFaultyResponses:
+ """Test that server works on a multidrop line."""
+
+ slaves = [0]
+
+ good_frame = b"\x00\x01\x00\x00\x00\x05\x00\x03\x02\x00\x01"
+
+ @pytest.fixture(name="framer")
+ def fixture_framer(self):
+ """Prepare framer."""
+ return ModbusSocketFramer(ClientDecoder())
+
+ @pytest.fixture(name="callback")
+ def fixture_callback(self):
+ """Prepare dummy callback."""
+ return mock.Mock()
+
+ def test_ok_frame(self, framer, callback):
+ """Test ok frame."""
+ framer.processIncomingPacket(self.good_frame, callback, self.slaves)
+ callback.assert_called_once()
+
+ def test_faulty_frame1(self, framer, callback):
+ """Test ok frame."""
+ faulty_frame = b"\x00\x04\x00\x00\x00\x05\x00\x03\x0a\x00\x04"
+ with pytest.raises(ModbusIOException):
+ framer.processIncomingPacket(faulty_frame, callback, self.slaves)
+ callback.assert_not_called()
+ framer.processIncomingPacket(self.good_frame, callback, self.slaves)
+ callback.assert_called_once()
diff --git a/test/test_client_sync.py b/test/test_client_sync.py
index 3c6845212..d093e1915 100755
--- a/test/test_client_sync.py
+++ b/test/test_client_sync.py
@@ -1,10 +1,10 @@
"""Test client sync."""
import ssl
-import unittest
from itertools import count
from test.conftest import mockSocket
-from unittest.mock import MagicMock, Mock, patch
+from unittest import mock
+import pytest
import serial
from pymodbus.client import (
@@ -29,9 +29,7 @@
# ---------------------------------------------------------------------------#
-class SynchronousClientTest(
- unittest.TestCase
-): # pylint: disable=too-many-public-methods
+class TestSynchronousClient: # pylint: disable=too-many-public-methods
"""Unittest for the pymodbus.client module."""
# -----------------------------------------------------------------------#
@@ -43,49 +41,56 @@ def test_basic_syn_udp_client(self):
# receive/send
client = ModbusUdpClient("127.0.0.1")
client.socket = mockSocket()
- self.assertEqual(0, client.send(None))
- self.assertEqual(1, client.send(b"\x50"))
- self.assertEqual(b"\x50", client.recv(1))
+ assert not client.send(None)
+ assert client.send(b"\x50") == 1
+ assert client.recv(1) == b"\x50"
# connect/disconnect
- self.assertTrue(client.connect())
+ assert client.connect()
client.close()
# already closed socket
client.socket = False
client.close()
- self.assertEqual("ModbusUdpClient(127.0.0.1:502)", str(client))
+ assert str(client) == "ModbusUdpClient(127.0.0.1:502)"
def test_udp_client_is_socket_open(self):
"""Test the udp client is_socket_open method"""
client = ModbusUdpClient("127.0.0.1")
- self.assertTrue(client.is_socket_open())
+ assert client.is_socket_open()
def test_udp_client_send(self):
"""Test the udp client send method"""
client = ModbusUdpClient("127.0.0.1")
- self.assertRaises(
- ConnectionException,
- lambda: client.send(None),
- )
-
+ with pytest.raises(ConnectionException):
+ client.send(None)
client.socket = mockSocket()
- self.assertEqual(0, client.send(None))
- self.assertEqual(4, client.send("1234"))
+ assert not client.send(None)
+ assert client.send("1234") == 4
def test_udp_client_recv(self):
"""Test the udp client receive method"""
client = ModbusUdpClient("127.0.0.1")
- self.assertRaises(
- ConnectionException,
- lambda: client.recv(1024),
- )
-
+ with pytest.raises(ConnectionException):
+ client.recv(1024)
client.socket = mockSocket()
- client.socket.mock_store(b"\x00" * 4)
- self.assertEqual(b"", client.recv(0))
- self.assertEqual(b"\x00" * 4, client.recv(4))
+ client.socket.mock_prepare_receive(b"\x00" * 4)
+ assert client.recv(0) == b""
+ assert client.recv(4) == b"\x00" * 4
+
+ def test_udp_client_recv_duplicate(self):
+ """Test the udp client receive method"""
+ test_msg = b"\x00\x01\x00\x00\x00\x05\x01\x04\x02\x00\x03"
+ client = ModbusUdpClient("127.0.0.1")
+ client.socket = mockSocket(copy_send=False)
+ client.socket.mock_prepare_receive(test_msg)
+ client.socket.mock_prepare_receive(test_msg)
+ reply_ok = client.read_input_registers(0x820, 1, 1)
+ assert not reply_ok.isError()
+ reply_none = client.read_input_registers(0x40, 10, 1)
+ assert reply_none.isError()
+ client.close()
def test_udp_client_repr(self):
"""Test udp client representation."""
@@ -94,7 +99,7 @@ def test_udp_client_repr(self):
f"<{client.__class__.__name__} at {hex(id(client))} socket={client.socket}, "
f"ipaddr={client.params.host}, port={client.params.port}, timeout={client.params.timeout}>"
)
- self.assertEqual(repr(client), rep)
+ assert repr(client) == rep
# -----------------------------------------------------------------------#
# Test TCP Client
@@ -103,87 +108,79 @@ def test_udp_client_repr(self):
def test_syn_tcp_client_instantiation(self):
"""Test sync tcp client."""
client = ModbusTcpClient("127.0.0.1")
- self.assertNotEqual(client, None)
+ assert client
- @patch("pymodbus.client.tcp.select")
+ @mock.patch("pymodbus.client.tcp.select")
def test_basic_syn_tcp_client(self, mock_select):
"""Test the basic methods for the tcp sync client"""
# receive/send
mock_select.select.return_value = [True]
client = ModbusTcpClient("127.0.0.1")
client.socket = mockSocket()
- self.assertEqual(0, client.send(None))
- self.assertEqual(1, client.send(b"\x45"))
- self.assertEqual(b"\x45", client.recv(1))
+ assert not client.send(None)
+ assert client.send(b"\x45") == 1
+ assert client.recv(1) == b"\x45"
# connect/disconnect
- self.assertTrue(client.connect())
+ assert client.connect()
client.close()
# already closed socket
client.socket = False
client.close()
- self.assertEqual("ModbusTcpClient(127.0.0.1:502)", str(client))
+ assert str(client) == "ModbusTcpClient(127.0.0.1:502)"
def test_tcp_client_is_socket_open(self):
"""Test the tcp client is_socket_open method"""
client = ModbusTcpClient("127.0.0.1")
- self.assertFalse(client.is_socket_open())
+ assert not client.is_socket_open()
def test_tcp_client_send(self):
"""Test the tcp client send method"""
client = ModbusTcpClient("127.0.0.1")
- self.assertRaises(
- ConnectionException,
- lambda: client.send(None),
- )
-
+ with pytest.raises(ConnectionException):
+ client.send(None)
client.socket = mockSocket()
- self.assertEqual(0, client.send(None))
- self.assertEqual(4, client.send("1234"))
+ assert not client.send(None)
+ assert client.send("1234") == 4
- @patch("pymodbus.client.tcp.time")
- @patch("pymodbus.client.tcp.select")
+ @mock.patch("pymodbus.client.tcp.time")
+ @mock.patch("pymodbus.client.tcp.select")
def test_tcp_client_recv(self, mock_select, mock_time):
"""Test the tcp client receive method"""
mock_select.select.return_value = [True]
mock_time.time.side_effect = count()
client = ModbusTcpClient("127.0.0.1")
- self.assertRaises(
- ConnectionException,
- lambda: client.recv(1024),
- )
+ with pytest.raises(ConnectionException):
+ client.recv(1024)
client.socket = mockSocket()
- self.assertEqual(b"", client.recv(0))
- client.socket.mock_store(b"\x00" * 4)
- self.assertEqual(b"\x00" * 4, client.recv(4))
+ assert client.recv(0) == b""
+ client.socket.mock_prepare_receive(b"\x00" * 4)
+ assert client.recv(4) == b"\x00" * 4
- mock_socket = MagicMock()
+ mock_socket = mock.MagicMock()
mock_socket.recv.side_effect = iter([b"\x00", b"\x01", b"\x02"])
client.socket = mock_socket
client.params.timeout = 3
- self.assertEqual(b"\x00\x01\x02", client.recv(3))
+ assert client.recv(3) == b"\x00\x01\x02"
mock_socket.recv.side_effect = iter([b"\x00", b"\x01", b"\x02"])
- self.assertEqual(b"\x00\x01", client.recv(2))
+ assert client.recv(2) == b"\x00\x01"
mock_select.select.return_value = [False]
- self.assertEqual(b"", client.recv(2))
+ assert client.recv(2) == b""
client.socket = mockSocket()
- client.socket.mock_store(b"\x00")
+ client.socket.mock_prepare_receive(b"\x00")
mock_select.select.return_value = [True]
- self.assertIn(b"\x00", client.recv(None))
+ assert client.recv(None) in b"\x00"
- mock_socket = MagicMock()
+ mock_socket = mock.MagicMock()
mock_socket.recv.return_value = b""
client.socket = mock_socket
- self.assertRaises(
- ConnectionException,
- lambda: client.recv(1024),
- )
-
+ with pytest.raises(ConnectionException):
+ client.recv(1024)
mock_socket.recv.side_effect = iter([b"\x00", b"\x01", b"\x02", b""])
client.socket = mock_socket
- self.assertEqual(b"\x00\x01\x02", client.recv(1024))
+ assert client.recv(1024) == b"\x00\x01\x02"
def test_tcp_client_repr(self):
"""Test tcp client."""
@@ -192,7 +189,7 @@ def test_tcp_client_repr(self):
f"<{client.__class__.__name__} at {hex(id(client))} socket={client.socket}, "
f"ipaddr={client.params.host}, port={client.params.port}, timeout={client.params.timeout}>"
)
- self.assertEqual(repr(client), rep)
+ assert repr(client) == rep
def test_tcp_client_register(self):
"""Test tcp client."""
@@ -203,9 +200,9 @@ class CustomRequest: # pylint: disable=too-few-public-methods
function_code = 79
client = ModbusTcpClient("127.0.0.1")
- client.framer = Mock()
+ client.framer = mock.Mock()
client.register(CustomRequest)
- self.assertTrue(client.framer.decoder.register.called_once_with(CustomRequest))
+ assert client.framer.decoder.register.called_once_with(CustomRequest)
# -----------------------------------------------------------------------#
# Test TLS Client
@@ -213,104 +210,87 @@ class CustomRequest: # pylint: disable=too-few-public-methods
def test_tls_sslctx_provider(self):
"""Test that sslctx_provider() produce SSLContext correctly"""
- with patch.object(ssl.SSLContext, "load_cert_chain") as mock_method:
+ with mock.patch.object(ssl.SSLContext, "load_cert_chain") as mock_method:
sslctx1 = sslctx_provider(certfile="cert.pem")
- self.assertIsNotNone(sslctx1)
- self.assertEqual(type(sslctx1), ssl.SSLContext)
- self.assertEqual(mock_method.called, False)
+ assert sslctx1
+ assert isinstance(sslctx1, ssl.SSLContext)
+ assert not mock_method.called
sslctx2 = sslctx_provider(keyfile="key.pem")
- self.assertIsNotNone(sslctx2)
- self.assertEqual(type(sslctx2), ssl.SSLContext)
- self.assertEqual(mock_method.called, False)
+ assert sslctx2
+ assert isinstance(sslctx2, ssl.SSLContext)
+ assert not mock_method.called
sslctx3 = sslctx_provider(certfile="cert.pem", keyfile="key.pem")
- self.assertIsNotNone(sslctx3)
- self.assertEqual(type(sslctx3), ssl.SSLContext)
- self.assertEqual(mock_method.called, True)
+ assert sslctx3
+ assert isinstance(sslctx3, ssl.SSLContext)
+ assert mock_method.called
sslctx_old = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
sslctx_new = sslctx_provider(sslctx=sslctx_old)
- self.assertEqual(sslctx_new, sslctx_old)
+ assert sslctx_new == sslctx_old
def test_syn_tls_client_instantiation(self):
"""Test sync tls client."""
# default SSLContext
client = ModbusTlsClient("127.0.0.1")
- self.assertNotEqual(client, None)
- self.assertIsInstance(client.framer, ModbusTlsFramer)
- self.assertTrue(client.sslctx)
+ assert client
+ assert isinstance(client.framer, ModbusTlsFramer)
+ assert client.sslctx
- @patch("pymodbus.client.tcp.select")
+ @mock.patch("pymodbus.client.tcp.select")
def test_basic_syn_tls_client(self, mock_select):
"""Test the basic methods for the tls sync client"""
# receive/send
mock_select.select.return_value = [True]
client = ModbusTlsClient("localhost")
client.socket = mockSocket()
- self.assertEqual(0, client.send(None))
- self.assertEqual(1, client.send(b"\x45"))
- self.assertEqual(b"\x45", client.recv(1))
+ assert not client.send(None)
+ assert client.send(b"\x45") == 1
+ assert client.recv(1) == b"\x45"
# connect/disconnect
- self.assertTrue(client.connect())
+ assert client.connect()
client.close()
# already closed socket
client.socket = False
client.close()
- self.assertEqual("ModbusTlsClient(localhost:802)", str(client))
+ assert str(client) == "ModbusTlsClient(localhost:802)"
client = ModbusTcpClient("127.0.0.1")
client.socket = mockSocket()
- self.assertEqual(0, client.send(None))
- self.assertEqual(1, client.send(b"\x45"))
- self.assertEqual(b"\x45", client.recv(1))
+ assert not client.send(None)
+ assert client.send(b"\x45") == 1
+ assert client.recv(1) == b"\x45"
def test_tls_client_send(self):
"""Test the tls client send method"""
client = ModbusTlsClient("127.0.0.1")
- self.assertRaises(
- ConnectionException,
- lambda: client.send(None),
- )
-
+ with pytest.raises(ConnectionException):
+ client.send(None)
client.socket = mockSocket()
- self.assertEqual(0, client.send(None))
- self.assertEqual(4, client.send("1234"))
+ assert not client.send(None)
+ assert client.send("1234") == 4
- @patch("pymodbus.client.tcp.time")
- @patch("pymodbus.client.tcp.select")
+ @mock.patch("pymodbus.client.tcp.time")
+ @mock.patch("pymodbus.client.tcp.select")
def test_tls_client_recv(self, mock_select, mock_time):
"""Test the tls client receive method"""
mock_select.select.return_value = [True]
client = ModbusTlsClient("127.0.0.1")
- self.assertRaises(
- ConnectionException,
- lambda: client.recv(1024),
- )
-
+ with pytest.raises(ConnectionException):
+ client.recv(1024)
mock_time.time.side_effect = count()
client.socket = mockSocket()
- client.socket.mock_store(b"\x00" * 4)
- self.assertEqual(b"", client.recv(0))
- self.assertEqual(b"\x00" * 4, client.recv(4))
+ client.socket.mock_prepare_receive(b"\x00" * 4)
+ assert client.recv(0) == b""
+ assert client.recv(4) == b"\x00" * 4
client.params.timeout = 2
- client.socket.mock_store(b"\x00")
- self.assertIn(b"\x00", client.recv(None))
-
- # client.socket = mockSocket()
- # client.socket.recv.side_effect = iter([b"\x00", b"\x01", b"\x02"])
- # client.params.timeout = 3
- # self.assertEqual(
- # b"\x00\x01\x02", client.recv(3)
- # )
- # client.socket.recv.side_effect = iter([b"\x00", b"\x01", b"\x02"])
- # self.assertEqual(
- # b"\x00\x01", client.recv(2)
- # )
+ client.socket.mock_prepare_receive(b"\x00")
+ assert b"\x00" in client.recv(None)
def test_tls_client_repr(self):
"""Test tls client."""
@@ -320,7 +300,7 @@ def test_tls_client_repr(self):
f"ipaddr={client.params.host}, port={client.params.port}, sslctx={client.sslctx}, "
f"timeout={client.params.timeout}>"
)
- self.assertEqual(repr(client), rep)
+ assert repr(client) == rep
def test_tls_client_register(self):
"""Test tls client."""
@@ -331,9 +311,9 @@ class CustomeRequest: # pylint: disable=too-few-public-methods
function_code = 79
client = ModbusTlsClient("127.0.0.1")
- client.framer = Mock()
+ client.framer = mock.Mock()
client.register(CustomeRequest)
- self.assertTrue(client.framer.decoder.register.called_once_with(CustomeRequest))
+ assert client.framer.decoder.register.called_once_with(CustomeRequest)
# -----------------------------------------------------------------------#
# Test Serial Client
@@ -341,40 +321,32 @@ class CustomeRequest: # pylint: disable=too-few-public-methods
def test_sync_serial_client_instantiation(self):
"""Test sync serial client."""
client = ModbusSerialClient("/dev/null")
- self.assertNotEqual(client, None)
- self.assertTrue(
- isinstance(
- ModbusSerialClient("/dev/null", framer=ModbusAsciiFramer).framer,
- ModbusAsciiFramer,
- )
+ assert client
+ assert isinstance(
+ ModbusSerialClient("/dev/null", framer=ModbusAsciiFramer).framer,
+ ModbusAsciiFramer,
)
- self.assertTrue(
- isinstance(
- ModbusSerialClient("/dev/null", framer=ModbusRtuFramer).framer,
- ModbusRtuFramer,
- )
+ assert isinstance(
+ ModbusSerialClient("/dev/null", framer=ModbusRtuFramer).framer,
+ ModbusRtuFramer,
)
- self.assertTrue(
- isinstance(
- ModbusSerialClient("/dev/null", framer=ModbusBinaryFramer).framer,
- ModbusBinaryFramer,
- )
+ assert isinstance(
+ ModbusSerialClient("/dev/null", framer=ModbusBinaryFramer).framer,
+ ModbusBinaryFramer,
)
- self.assertTrue(
- isinstance(
- ModbusSerialClient("/dev/null", framer=ModbusSocketFramer).framer,
- ModbusSocketFramer,
- )
+ assert isinstance(
+ ModbusSerialClient("/dev/null", framer=ModbusSocketFramer).framer,
+ ModbusSocketFramer,
)
def test_sync_serial_rtu_client_timeouts(self):
"""Test sync serial rtu."""
client = ModbusSerialClient("/dev/null", framer=ModbusRtuFramer, baudrate=9600)
- self.assertEqual(client.silent_interval, round((3.5 * 11 / 9600), 6))
+ assert client.silent_interval == round((3.5 * 11 / 9600), 6)
client = ModbusSerialClient("/dev/null", framer=ModbusRtuFramer, baudrate=38400)
- self.assertEqual(client.silent_interval, round((1.75 / 1000), 6))
+ assert client.silent_interval == round((1.75 / 1000), 6)
- @patch("serial.Serial")
+ @mock.patch("serial.Serial")
def test_basic_sync_serial_client(self, mock_serial):
"""Test the basic methods for the serial sync client."""
# receive/send
@@ -385,25 +357,23 @@ def test_basic_sync_serial_client(self, mock_serial):
client = ModbusSerialClient("/dev/null")
client.socket = mock_serial
client.state = 0
- self.assertEqual(0, client.send(None))
+ assert not client.send(None)
client.state = 0
- self.assertEqual(1, client.send(b"\x00"))
- self.assertEqual(b"\x00", client.recv(1))
+ assert client.send(b"\x00") == 1
+ assert client.recv(1) == b"\x00"
# connect/disconnect
- self.assertTrue(client.connect())
+ assert client.connect()
client.close()
# rtu connect/disconnect
rtu_client = ModbusSerialClient(
"/dev/null", framer=ModbusRtuFramer, strict=True
)
- self.assertTrue(rtu_client.connect())
- self.assertEqual(
- rtu_client.socket.interCharTimeout, rtu_client.inter_char_timeout
- )
+ assert rtu_client.connect()
+ assert rtu_client.socket.interCharTimeout == rtu_client.inter_char_timeout
rtu_client.close()
- self.assertTrue("baud[19200])" in str(client))
+ assert "baud[19200])" in str(client)
# already closed socket
client.socket = False
@@ -411,76 +381,67 @@ def test_basic_sync_serial_client(self, mock_serial):
def test_serial_client_connect(self):
"""Test the serial client connection method"""
- with patch.object(serial, "Serial") as mock_method:
- mock_method.return_value = MagicMock()
+ with mock.patch.object(serial, "Serial") as mock_method:
+ mock_method.return_value = mock.MagicMock()
client = ModbusSerialClient("/dev/null")
- self.assertTrue(client.connect())
+ assert client.connect()
- with patch.object(serial, "Serial") as mock_method:
+ with mock.patch.object(serial, "Serial") as mock_method:
mock_method.side_effect = serial.SerialException()
client = ModbusSerialClient("/dev/null")
- self.assertFalse(client.connect())
+ assert not client.connect()
- @patch("serial.Serial")
+ @mock.patch("serial.Serial")
def test_serial_client_is_socket_open(self, mock_serial):
"""Test the serial client is_socket_open method"""
client = ModbusSerialClient("/dev/null")
- self.assertFalse(client.is_socket_open())
+ assert not client.is_socket_open()
client.socket = mock_serial
- self.assertTrue(client.is_socket_open())
+ assert client.is_socket_open()
- @patch("serial.Serial")
+ @mock.patch("serial.Serial")
def test_serial_client_send(self, mock_serial):
"""Test the serial client send method"""
mock_serial.in_waiting = None
mock_serial.write = lambda x: len(x) # pylint: disable=unnecessary-lambda
client = ModbusSerialClient("/dev/null")
- self.assertRaises(
- ConnectionException,
- lambda: client.send(None),
- )
- # client.connect()
+ with pytest.raises(ConnectionException):
+ client.send(None)
client.socket = mock_serial
client.state = 0
- self.assertEqual(0, client.send(None))
+ assert not client.send(None)
client.state = 0
- self.assertEqual(4, client.send("1234"))
+ assert client.send("1234") == 4
- @patch("serial.Serial")
+ @mock.patch("serial.Serial")
def test_serial_client_cleanup_buffer_before_send(self, mock_serial):
"""Test the serial client send method"""
mock_serial.in_waiting = 4
mock_serial.read = lambda x: b"1" * x
mock_serial.write = lambda x: len(x) # pylint: disable=unnecessary-lambda
client = ModbusSerialClient("/dev/null")
- self.assertRaises(
- ConnectionException,
- lambda: client.send(None),
- )
- # client.connect()
+ with pytest.raises(ConnectionException):
+ client.send(None)
client.socket = mock_serial
client.state = 0
- self.assertEqual(0, client.send(None))
+ assert not client.send(None)
client.state = 0
- self.assertEqual(4, client.send("1234"))
+ assert client.send("1234") == 4
def test_serial_client_recv(self):
"""Test the serial client receive method"""
client = ModbusSerialClient("/dev/null")
- self.assertRaises(
- ConnectionException,
- lambda: client.recv(1024),
- )
-
+ with pytest.raises(ConnectionException):
+ client.recv(1024)
client.socket = mockSocket()
- self.assertEqual(b"", client.recv(0))
- client.socket.mock_store(b"\x00" * 4)
- self.assertEqual(b"\x00" * 4, client.recv(4))
+ assert client.recv(0) == b""
+ client.socket.mock_prepare_receive(b"\x00" * 4)
+ assert client.recv(4) == b"\x00" * 4
client.socket = mockSocket()
- client.socket.mock_store(b"")
- self.assertEqual(b"", client.recv(None))
+ client.socket.mock_prepare_receive(b"")
+ assert client.recv(None) == b""
client.socket.timeout = 0
- self.assertEqual(b"", client.recv(0))
+ assert client.recv(0) == b""
def test_serial_client_repr(self):
"""Test serial client."""
@@ -489,4 +450,4 @@ def test_serial_client_repr(self):
f"<{client.__class__.__name__} at {hex(id(client))} socket={client.socket}, "
f"framer={client.framer}, timeout={client.params.timeout}>"
)
- self.assertEqual(repr(client), rep)
+ assert repr(client) == rep
diff --git a/test/test_client_sync_diag.py b/test/test_client_sync_diag.py
deleted file mode 100755
index 69d8508f3..000000000
--- a/test/test_client_sync_diag.py
+++ /dev/null
@@ -1,114 +0,0 @@
-"""Test client sync diag."""
-import socket
-import unittest
-from itertools import count
-from test.test_client_sync import mockSocket
-from unittest.mock import MagicMock, patch
-
-from pymodbus.client.sync_diag import ModbusTcpDiagClient, get_client
-from pymodbus.exceptions import ConnectionException
-
-
-# ---------------------------------------------------------------------------#
-# Fixture
-# ---------------------------------------------------------------------------#
-
-
-class SynchronousDiagnosticClientTest(unittest.TestCase):
- """Unittest for the pymodbus.client.sync_diag module.
-
- It is a copy of parts of the test for the TCP class in the pymodbus.client
- module, as it should operate identically and only log some additional
- lines.
- """
-
- # -----------------------------------------------------------------------#
- # Test TCP Diagnostic Client
- # -----------------------------------------------------------------------#
-
- def test_syn_tcp_diag_client_instantiation(self):
- """Test sync tcp diag client."""
- client = get_client()
- self.assertNotEqual(client, None)
-
- def test_basic_syn_tcp_diag_client(self):
- """Test the basic methods for the tcp sync diag client"""
- # connect/disconnect
- client = ModbusTcpDiagClient()
- client.socket = mockSocket()
- self.assertTrue(client.connect())
- client.close()
-
- def test_tcp_diag_client_connect(self):
- """Test the tcp sync diag client connection method"""
- with patch.object(socket, "create_connection") as mock_method:
- mock_method.return_value = object()
- client = ModbusTcpDiagClient()
- self.assertTrue(client.connect())
-
- with patch.object(socket, "create_connection") as mock_method:
- mock_method.side_effect = socket.error()
- client = ModbusTcpDiagClient()
- self.assertFalse(client.connect())
-
- @patch("pymodbus.client.tcp.time")
- @patch("pymodbus.client.sync_diag.time")
- @patch("pymodbus.client.tcp.select")
- def test_tcp_diag_client_recv(self, mock_select, mock_diag_time, mock_time):
- """Test the tcp sync diag client receive method"""
- mock_select.select.return_value = [True]
- mock_time.time.side_effect = count()
- mock_diag_time.time.side_effect = count()
- client = ModbusTcpDiagClient()
- self.assertRaises(
- ConnectionException,
- lambda: client.recv(1024),
- )
-
- client.socket = mockSocket()
- # Test logging of non-delayed responses
- client.socket.mock_store(b"\x00")
- self.assertIn(b"\x00", client.recv(None))
- client.socket = mockSocket()
- client.socket.mock_store(b"\x00")
- self.assertEqual(b"\x00", client.recv(1))
-
- # Fool diagnostic logger into thinking we"re running late,
- # test logging of delayed responses
- mock_diag_time.time.side_effect = count(step=3)
- client.socket.mock_store(b"\x00" * 4)
- self.assertEqual(b"\x00" * 4, client.recv(4))
- self.assertEqual(b"", client.recv(0))
-
- client.socket.mock_store(b"\x00\x01\x02")
- client.timeout = 3
- self.assertEqual(b"\x00\x01\x02", client.recv(3))
- client.socket.mock_store(b"\x00\x01\x02")
- self.assertEqual(b"\x00\x01", client.recv(2))
- mock_select.select.return_value = [False]
- self.assertEqual(b"", client.recv(2))
- client.socket = mockSocket()
- client.socket.mock_store(b"\x00")
- mock_select.select.return_value = [True]
- self.assertIn(b"\x00", client.recv(None))
-
- mock_socket = MagicMock()
- client.socket = mock_socket
- mock_socket.recv.return_value = b""
- self.assertRaises(
- ConnectionException,
- lambda: client.recv(1024),
- )
- client.socket = mockSocket()
- client.socket.mock_store(b"\x00\x01\x02")
- self.assertEqual(b"\x00\x01\x02", client.recv(1024))
-
- def test_tcp_diag_client_repr(self):
- """Test tcp diag client."""
- client = ModbusTcpDiagClient()
- rep = (
- f"<{client.__class__.__name__} at {hex(id(client))} "
- f"socket={client.socket}, ipaddr={client.params.host}, "
- f"port={client.params.port}, timeout={client.params.timeout}>"
- )
- self.assertEqual(repr(client), rep)
diff --git a/test/test_datastore.py b/test/test_datastore.py
deleted file mode 100644
index c1bf26c48..000000000
--- a/test/test_datastore.py
+++ /dev/null
@@ -1,504 +0,0 @@
-"""Test datastore."""
-import random
-import unittest
-from unittest.mock import MagicMock
-
-import pytest
-import redis
-
-from pymodbus.datastore import (
- ModbusSequentialDataBlock,
- ModbusServerContext,
- ModbusSlaveContext,
- ModbusSparseDataBlock,
-)
-from pymodbus.datastore.database import RedisSlaveContext, SqlSlaveContext
-from pymodbus.datastore.store import BaseModbusDataBlock
-from pymodbus.exceptions import (
- NoSuchSlaveException,
- NotImplementedException,
- ParameterException,
-)
-
-
-class ModbusDataStoreTest(unittest.TestCase):
- """Unittest for the pymodbus.datastore module."""
-
- def setUp(self):
- """Do setup."""
-
- def tearDown(self):
- """Clean up the test environment"""
-
- def test_modbus_data_block(self):
- """Test a base data block store"""
- block = BaseModbusDataBlock()
- block.default(10, True)
-
- self.assertNotEqual(str(block), None)
- self.assertEqual(block.default_value, True)
- self.assertEqual(block.values, [True] * 10)
-
- block.default_value = False
- block.reset()
- self.assertEqual(block.values, [False] * 10)
-
- def test_modbus_data_block_iterate(self):
- """Test a base data block store"""
- block = BaseModbusDataBlock()
- block.default(10, False)
- for _, value in block:
- self.assertEqual(value, False)
-
- block.values = {0: False, 2: False, 3: False}
- for _, value in block:
- self.assertEqual(value, False)
-
- def test_modbus_data_block_other(self):
- """Test a base data block store"""
- block = BaseModbusDataBlock()
- self.assertRaises(NotImplementedException, lambda: block.validate(1, 1))
- self.assertRaises(NotImplementedException, lambda: block.getValues(1, 1))
- self.assertRaises(NotImplementedException, lambda: block.setValues(1, 1))
-
- def test_modbus_sequential_data_block(self):
- """Test a sequential data block store"""
- block = ModbusSequentialDataBlock(0x00, [False] * 10)
- self.assertFalse(block.validate(-1, 0))
- self.assertFalse(block.validate(0, 20))
- self.assertFalse(block.validate(10, 1))
- self.assertTrue(block.validate(0x00, 10))
-
- block.setValues(0x00, True)
- self.assertEqual(block.getValues(0x00, 1), [True])
-
- block.setValues(0x00, [True] * 10)
- self.assertEqual(block.getValues(0x00, 10), [True] * 10)
-
- def test_modbus_sequential_data_block_factory(self):
- """Test the sequential data block store factory"""
- block = ModbusSequentialDataBlock.create()
- self.assertEqual(block.getValues(0x00, 65536), [False] * 65536)
- block = ModbusSequentialDataBlock(0x00, 0x01)
- self.assertEqual(block.values, [0x01])
-
- def test_modbus_sparse_data_block(self):
- """Test a sparse data block store"""
- values = dict(enumerate([True] * 10))
- block = ModbusSparseDataBlock(values)
- self.assertFalse(block.validate(-1, 0))
- self.assertFalse(block.validate(0, 20))
- self.assertFalse(block.validate(10, 1))
- self.assertTrue(block.validate(0x00, 10))
- self.assertTrue(block.validate(0x00, 10))
- self.assertFalse(block.validate(0, 0))
- self.assertFalse(block.validate(5, 0))
-
- block.setValues(0x00, True)
- self.assertEqual(block.getValues(0x00, 1), [True])
-
- block.setValues(0x00, [True] * 10)
- self.assertEqual(block.getValues(0x00, 10), [True] * 10)
-
- block.setValues(0x00, dict(enumerate([False] * 10)))
- self.assertEqual(block.getValues(0x00, 10), [False] * 10)
-
- block = ModbusSparseDataBlock({3: [10, 11, 12], 10: 1, 15: [0] * 4})
- self.assertEqual(
- block.values, {3: 10, 4: 11, 5: 12, 10: 1, 15: 0, 16: 0, 17: 0, 18: 0}
- )
- self.assertEqual(
- block.default_value,
- {3: 10, 4: 11, 5: 12, 10: 1, 15: 0, 16: 0, 17: 0, 18: 0},
- )
- self.assertEqual(block.mutable, True)
- block.setValues(3, [20, 21, 22, 23], use_as_default=True)
- self.assertEqual(block.getValues(3, 4), [20, 21, 22, 23])
- self.assertEqual(
- block.default_value,
- {3: 20, 4: 21, 5: 22, 6: 23, 10: 1, 15: 0, 16: 0, 17: 0, 18: 0},
- )
- # check when values is a dict, address is ignored
- block.setValues(0, {5: 32, 7: 43})
- self.assertEqual(block.getValues(5, 3), [32, 23, 43])
-
- # assertEqual value is empty dict when initialized without params
- block = ModbusSparseDataBlock()
- self.assertEqual(block.values, {})
-
- # mark block as unmutable and see if parameter exception
- # is raised for invalid offset writes
- block = ModbusSparseDataBlock({1: 100}, mutable=False)
- self.assertRaises(ParameterException, block.setValues, 0, 1)
- self.assertRaises(ParameterException, block.setValues, 0, {2: 100})
- self.assertRaises(ParameterException, block.setValues, 0, [1] * 10)
-
- # Reset datablock
- block = ModbusSparseDataBlock({3: [10, 11, 12], 10: 1, 15: [0] * 4})
- block.setValues(0, {3: [20, 21, 22], 10: 11, 15: [10] * 4})
- self.assertEqual(
- block.values, {3: 20, 4: 21, 5: 22, 10: 11, 15: 10, 16: 10, 17: 10, 18: 10}
- )
- block.reset()
- self.assertEqual(
- block.values, {3: 10, 4: 11, 5: 12, 10: 1, 15: 0, 16: 0, 17: 0, 18: 0}
- )
-
- def test_modbus_sparse_data_block_factory(self):
- """Test the sparse data block store factory"""
- block = ModbusSparseDataBlock.create([0x00] * 65536)
- self.assertEqual(block.getValues(0x00, 65536), [False] * 65536)
-
- def test_modbus_sparse_data_block_other(self):
- """Test modbus sparce data block."""
- block = ModbusSparseDataBlock([True] * 10)
- self.assertEqual(block.getValues(0x00, 10), [True] * 10)
- self.assertRaises(ParameterException, lambda: ModbusSparseDataBlock(True))
-
- def test_modbus_slave_context(self):
- """Test a modbus slave context"""
- store = {
- "di": ModbusSequentialDataBlock(0, [False] * 10),
- "co": ModbusSequentialDataBlock(0, [False] * 10),
- "ir": ModbusSequentialDataBlock(0, [False] * 10),
- "hr": ModbusSequentialDataBlock(0, [False] * 10),
- }
- context = ModbusSlaveContext(**store)
- self.assertNotEqual(str(context), None)
-
- for i in (1, 2, 3, 4):
- context.setValues(i, 0, [True] * 10)
- self.assertTrue(context.validate(i, 0, 10))
- self.assertEqual(context.getValues(i, 0, 10), [True] * 10)
- context.reset()
-
- for i in (1, 2, 3, 4):
- self.assertTrue(context.validate(i, 0, 10))
- self.assertEqual(context.getValues(i, 0, 10), [False] * 10)
-
- def test_modbus_server_context(self):
- """Test a modbus server context"""
-
- def _set(ctx):
- ctx[0xFFFF] = None
-
- context = ModbusServerContext(single=False)
- self.assertRaises(NoSuchSlaveException, lambda: _set(context))
- self.assertRaises(NoSuchSlaveException, lambda: context[0xFFFF])
-
-
-class RedisDataStoreTest(unittest.TestCase):
- """Unittest for the pymodbus.datastore.database.redis module."""
-
- def setUp(self):
- """Do setup."""
- self.slave = RedisSlaveContext()
-
- def tearDown(self):
- """Clean up the test environment"""
-
- def test_str(self):
- """Test string."""
- # slave = RedisSlaveContext()
- self.assertEqual(str(self.slave), f"Redis Slave Context {self.slave.client}")
-
- def test_reset(self):
- """Test reset."""
- self.assertTrue(isinstance(self.slave.client, redis.Redis))
- self.slave.client = MagicMock()
- self.slave.reset()
- self.slave.client.flushall.assert_called_once_with()
-
- def test_val_callbacks_success(self):
- """Test value callbacks success."""
- self.slave._build_mapping() # pylint: disable=protected-access
- mock_count = 3
- mock_offset = 0
- self.slave.client.mset = MagicMock()
- self.slave.client.mget = MagicMock(return_value=["11"])
-
- for key in ("d", "c", "h", "i"):
- self.assertTrue(
- self.slave._val_callbacks[key]( # pylint: disable=protected-access
- mock_offset, mock_count
- )
- )
-
- def test_val_callbacks_failure(self):
- """Test value callbacks failure."""
- self.slave._build_mapping() # pylint: disable=protected-access
- mock_count = 3
- mock_offset = 0
- self.slave.client.mset = MagicMock()
- self.slave.client.mget = MagicMock(return_value=["11", None])
-
- for key in ("d", "c", "h", "i"):
- self.assertFalse(
- self.slave._val_callbacks[key]( # pylint: disable=protected-access
- mock_offset, mock_count
- )
- )
-
- def test_get_callbacks(self):
- """Test get callbacks."""
- self.slave._build_mapping() # pylint: disable=protected-access
- mock_count = 3
- mock_offset = 0
- self.slave.client.mget = MagicMock(return_value="11")
-
- for key in ("d", "c"):
- resp = self.slave._get_callbacks[key]( # pylint: disable=protected-access
- mock_offset, mock_count
- )
- self.assertEqual(resp, [True, False, False])
-
- for key in ("h", "i"):
- resp = self.slave._get_callbacks[key]( # pylint: disable=protected-access
- mock_offset, mock_count
- )
- self.assertEqual(resp, ["1", "1"])
-
- def test_set_callbacks(self):
- """Test set callbacks."""
- self.slave._build_mapping() # pylint: disable=protected-access
- mock_values = [3]
- mock_offset = 0
- self.slave.client.mset = MagicMock()
- self.slave.client.mget = MagicMock()
-
- for key in ("c", "d"):
- self.slave._set_callbacks[key]( # pylint: disable=protected-access
- mock_offset, [3]
- )
- k = f"pymodbus:{key}:{mock_offset}"
- self.slave.client.mset.assert_called_with({k: "\x01"})
-
- for key in ("h", "i"):
- self.slave._set_callbacks[key]( # pylint: disable=protected-access
- mock_offset, [3]
- )
- k = f"pymodbus:{key}:{mock_offset}"
- self.slave.client.mset.assert_called_with({k: mock_values[0]})
-
- def test_validate(self):
- """Test validate."""
- self.slave.client.mget = MagicMock(return_value=[123])
- self.assertTrue(self.slave.validate(0x01, 3000))
-
- def test_set_value(self):
- """Test set value."""
- self.slave.client.mset = MagicMock()
- self.slave.client.mget = MagicMock()
- self.assertEqual(self.slave.setValues(0x01, 1000, [12]), None)
-
- def test_get_value(self):
- """Test get value."""
- self.slave.client.mget = MagicMock(return_value=["123"])
- self.assertEqual(self.slave.getValues(0x01, 23), [])
-
-
-class MockSqlResult: # pylint: disable=too-few-public-methods
- """Mock SQL Result."""
-
- def __init__(self, rowcount=0, value=0):
- """Initialize."""
- self.rowcount = rowcount
- self.value = value
-
-
-class SqlDataStoreTest(unittest.TestCase):
- """Unittest for the pymodbus.datastore.database.SqlSlaveContext module."""
-
- class SQLunit: # pylint: disable=too-few-public-methods
- """Single test setup."""
-
- def __init__(self):
- """Prepare test."""
- self.slave = SqlSlaveContext()
- self.slave._metadata.drop_all = MagicMock()
- self.slave._db_create = MagicMock()
- self.slave._table.select = MagicMock()
- self.slave._connection = MagicMock()
-
- self.mock_addr = random.randint(0, 65000)
- self.mock_values = random.sample(range(1, 100), 5)
- self.mock_function = 0x01
- self.mock_type = "h"
- self.mock_offset = 0
- self.mock_count = 1
-
- self.function_map = {2: "d", 4: "i"}
- self.function_map.update([(i, "h") for i in (3, 6, 16, 22, 23)])
- self.function_map.update([(i, "c") for i in (1, 5, 15)])
-
- def setUp(self):
- """Do setup."""
-
- def tearDown(self):
- """Clean up the test environment"""
-
- @pytest.mark.skip
- @pytest.mark.xdist_group(name="sql")
- def test_str(self):
- """Test string."""
- unit = self.SQLunit()
- self.assertEqual(str(unit.slave), "Modbus Slave Context")
-
- @pytest.mark.skip
- @pytest.mark.xdist_group(name="sql")
- def test_reset(self):
- """Test reset."""
- unit = self.SQLunit()
- unit.slave.reset()
-
- unit.slave._metadata.drop_all.assert_called_once_with() # pylint: disable=protected-access
- unit.slave._db_create.assert_called_once_with( # pylint: disable=protected-access
- unit.slave.table, unit.slave.database
- )
-
- @pytest.mark.skip
- @pytest.mark.xdist_group(name="sql")
- def test_validate_success(self):
- """Test validate success."""
- unit = self.SQLunit()
- unit.slave._connection.execute.return_value.fetchall.return_value = ( # pylint: disable=protected-access
- unit.mock_values
- )
- self.assertTrue(
- unit.slave.validate(
- unit.mock_function, unit.mock_addr, len(unit.mock_values)
- )
- )
-
- @pytest.mark.skip
- @pytest.mark.xdist_group(name="sql")
- def test_validate_failure(self):
- """Test validate failure."""
- unit = self.SQLunit()
- wrong_count = 9
- unit.slave._connection.execute.return_value.fetchall.return_value = ( # pylint: disable=protected-access
- unit.mock_values
- )
- self.assertFalse(
- unit.slave.validate(unit.mock_function, unit.mock_addr, wrong_count)
- )
-
- @pytest.mark.skip
- @pytest.mark.xdist_group(name="sql")
- def test_build_set(self):
- """Test build set."""
- unit = self.SQLunit()
- mock_set = [
- {"index": 0, "type": "h", "value": 11},
- {"index": 1, "type": "h", "value": 12},
- ]
- self.assertListEqual(
- unit.slave._build_set("h", 0, [11, 12]), # pylint: disable=protected-access
- mock_set,
- )
-
- @pytest.mark.skip
- @pytest.mark.xdist_group(name="sql")
- def test_check_success(self):
- """Test check success."""
- unit = self.SQLunit()
- mock_success_results = [1, 2, 3]
- unit.slave._get = MagicMock( # pylint: disable=protected-access
- return_value=mock_success_results
- )
- self.assertFalse(
- unit.slave._check("h", 0, 1) # pylint: disable=protected-access
- )
-
- @pytest.mark.skip
- @pytest.mark.xdist_group(name="sql")
- def test_check_failure(self):
- """Test check failure."""
- unit = self.SQLunit()
- mock_success_results = []
- unit.slave._get = MagicMock( # pylint: disable=protected-access
- return_value=mock_success_results
- )
- self.assertTrue(
- unit.slave._check("h", 0, 1) # pylint: disable=protected-access
- )
-
- @pytest.mark.skip
- @pytest.mark.xdist_group(name="sql")
- def test_get_values(self):
- """Test get values."""
- unit = self.SQLunit()
- unit.slave._get = MagicMock() # pylint: disable=protected-access
-
- for key, value in unit.function_map.items():
- unit.slave.getValues(key, unit.mock_addr, unit.mock_count)
- unit.slave._get.assert_called_with( # pylint: disable=protected-access
- value, unit.mock_addr + 1, unit.mock_count
- )
-
- @pytest.mark.skip
- @pytest.mark.xdist_group(name="sql")
- def test_set_values(self):
- """Test set values."""
- unit = self.SQLunit()
- unit.slave._set = MagicMock() # pylint: disable=protected-access
-
- for key, value in unit.function_map.items():
- unit.slave.setValues(key, unit.mock_addr, unit.mock_values, update=False)
- unit.slave._set.assert_called_with( # pylint: disable=protected-access
- value, unit.mock_addr + 1, unit.mock_values
- )
-
- @pytest.mark.skip
- @pytest.mark.xdist_group(name="sql")
- def test_set(self):
- """Test set."""
- unit = self.SQLunit()
- unit.slave._check = MagicMock( # pylint: disable=protected-access
- return_value=True
- )
- unit.slave._connection.execute = MagicMock( # pylint: disable=protected-access
- return_value=MockSqlResult(rowcount=len(unit.mock_values))
- )
- self.assertTrue(
- unit.slave._set( # pylint: disable=protected-access
- unit.mock_type, unit.mock_offset, unit.mock_values
- )
- )
-
- unit.slave._check = MagicMock( # pylint: disable=protected-access
- return_value=False
- )
- self.assertFalse(
- unit.slave._set( # pylint: disable=protected-access
- unit.mock_type, unit.mock_offset, unit.mock_values
- )
- )
-
- @pytest.mark.skip
- @pytest.mark.xdist_group(name="sql")
- def test_update_success(self):
- """Test update success."""
- unit = self.SQLunit()
- unit.slave._connection.execute = MagicMock( # pylint: disable=protected-access
- return_value=MockSqlResult(rowcount=len(unit.mock_values))
- )
- self.assertTrue(
- unit.slave._update( # pylint: disable=protected-access
- unit.mock_type, unit.mock_offset, unit.mock_values
- )
- )
-
- @pytest.mark.skip
- @pytest.mark.xdist_group(name="sql")
- def test_update_failure(self):
- """Test update failure."""
- unit = self.SQLunit()
- unit.slave._connection.execute = MagicMock( # pylint: disable=protected-access
- return_value=MockSqlResult(rowcount=100)
- )
- self.assertFalse(
- unit.slave._update( # pylint: disable=protected-access
- unit.mock_type, unit.mock_offset, unit.mock_values
- )
- )
diff --git a/test/test_device.py b/test/test_device.py
index d32ab3506..970fb7ffd 100644
--- a/test/test_device.py
+++ b/test/test_device.py
@@ -1,6 +1,4 @@
"""Test device."""
-import unittest
-
from pymodbus.constants import DeviceInformation
from pymodbus.device import (
DeviceInformationFactory,
@@ -16,14 +14,18 @@
# ---------------------------------------------------------------------------#
-class SimpleDataStoreTest(unittest.TestCase):
+class TestDataStore:
"""Unittest for the pymodbus.device module."""
# -----------------------------------------------------------------------#
# Setup/TearDown
# -----------------------------------------------------------------------#
- def setUp(self):
+ info = None
+ ident = None
+ control = None
+
+ def setup_method(self):
"""Do setup."""
self.info = {
0x00: "Bashwork", # VendorName
@@ -32,7 +34,7 @@ def setUp(self):
0x03: "http://internets.com", # VendorUrl
0x04: "pymodbus", # ProductName
0x05: "bashwork", # ModelName
- 0x06: "unittest", # UserApplicationName
+ 0x06: "pytest", # UserApplicationName
0x07: "x", # reserved
0x08: "x", # reserved
0x10: "reserved", # reserved
@@ -44,21 +46,16 @@ def setUp(self):
self.control = ModbusControlBlock()
self.control.reset()
- def tearDown(self):
- """Clean up the test environment"""
- del self.ident
- del self.control
-
def test_update_identity(self):
"""Test device identification reading"""
self.control.Identity.update(self.ident)
- self.assertEqual(self.control.Identity.VendorName, "Bashwork")
- self.assertEqual(self.control.Identity.ProductCode, "PTM")
- self.assertEqual(self.control.Identity.MajorMinorRevision, "1.0")
- self.assertEqual(self.control.Identity.VendorUrl, "http://internets.com")
- self.assertEqual(self.control.Identity.ProductName, "pymodbus")
- self.assertEqual(self.control.Identity.ModelName, "bashwork")
- self.assertEqual(self.control.Identity.UserApplicationName, "unittest")
+ assert self.control.Identity.VendorName == "Bashwork"
+ assert self.control.Identity.ProductCode == "PTM"
+ assert self.control.Identity.MajorMinorRevision == "1.0"
+ assert self.control.Identity.VendorUrl == "http://internets.com"
+ assert self.control.Identity.ProductName == "pymodbus"
+ assert self.control.Identity.ModelName == "bashwork"
+ assert self.control.Identity.UserApplicationName == "pytest"
def test_device_identification_factory(self):
"""Test device identification reading"""
@@ -66,107 +63,119 @@ def test_device_identification_factory(self):
result = DeviceInformationFactory.get(
self.control, DeviceInformation.Specific, 0x00
)
- self.assertEqual(result[0x00], "Bashwork")
+ assert result[0x00] == "Bashwork"
result = DeviceInformationFactory.get(
self.control, DeviceInformation.Basic, 0x00
)
- self.assertEqual(result[0x00], "Bashwork")
- self.assertEqual(result[0x01], "PTM")
- self.assertEqual(result[0x02], "1.0")
+ assert result[0x00] == "Bashwork"
+ assert result[0x01] == "PTM"
+ assert result[0x02] == "1.0"
result = DeviceInformationFactory.get(
self.control, DeviceInformation.Regular, 0x00
)
- self.assertEqual(result[0x00], "Bashwork")
- self.assertEqual(result[0x01], "PTM")
- self.assertEqual(result[0x02], "1.0")
- self.assertEqual(result[0x03], "http://internets.com")
- self.assertEqual(result[0x04], "pymodbus")
- self.assertEqual(result[0x05], "bashwork")
- self.assertEqual(result[0x06], "unittest")
+ assert result[0x00] == "Bashwork"
+ assert result[0x01] == "PTM"
+ assert result[0x02] == "1.0"
+ assert result[0x03] == "http://internets.com"
+ assert result[0x04] == "pymodbus"
+ assert result[0x05] == "bashwork"
+ assert result[0x06] == "pytest"
def test_device_identification_factory_lookup(self):
"""Test device identification factory lookup."""
result = DeviceInformationFactory.get(
self.control, DeviceInformation.Basic, 0x00
)
- self.assertEqual(sorted(result.keys()), [0x00, 0x01, 0x02])
+ assert sorted(result.keys()) == [0x00, 0x01, 0x02]
result = DeviceInformationFactory.get(
self.control, DeviceInformation.Basic, 0x02
)
- self.assertEqual(sorted(result.keys()), [0x02])
+ assert sorted(result.keys()) == [0x02]
result = DeviceInformationFactory.get(
self.control, DeviceInformation.Regular, 0x00
)
- self.assertEqual(
- sorted(result.keys()), [0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06]
- )
+ assert sorted(result.keys()) == [0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06]
result = DeviceInformationFactory.get(
self.control, DeviceInformation.Regular, 0x01
)
- self.assertEqual(sorted(result.keys()), [0x01, 0x02, 0x03, 0x04, 0x05, 0x06])
+ assert sorted(result.keys()) == [0x01, 0x02, 0x03, 0x04, 0x05, 0x06]
result = DeviceInformationFactory.get(
self.control, DeviceInformation.Regular, 0x05
)
- self.assertEqual(sorted(result.keys()), [0x05, 0x06])
+ assert sorted(result.keys()) == [0x05, 0x06]
result = DeviceInformationFactory.get(
self.control, DeviceInformation.Extended, 0x00
)
- self.assertEqual(
- sorted(result.keys()),
- [0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x80, 0x82, 0xFF],
- )
+ assert sorted(result.keys()) == [
+ 0x00,
+ 0x01,
+ 0x02,
+ 0x03,
+ 0x04,
+ 0x05,
+ 0x06,
+ 0x80,
+ 0x82,
+ 0xFF,
+ ]
result = DeviceInformationFactory.get(
self.control, DeviceInformation.Extended, 0x02
)
- self.assertEqual(
- sorted(result.keys()), [0x02, 0x03, 0x04, 0x05, 0x06, 0x80, 0x82, 0xFF]
- )
+ assert sorted(result.keys()) == [0x02, 0x03, 0x04, 0x05, 0x06, 0x80, 0x82, 0xFF]
result = DeviceInformationFactory.get(
self.control, DeviceInformation.Extended, 0x06
)
- self.assertEqual(sorted(result.keys()), [0x06, 0x80, 0x82, 0xFF])
+ assert sorted(result.keys()) == [0x06, 0x80, 0x82, 0xFF]
result = DeviceInformationFactory.get(
self.control, DeviceInformation.Extended, 0x80
)
- self.assertEqual(sorted(result.keys()), [0x80, 0x82, 0xFF])
+ assert sorted(result.keys()) == [0x80, 0x82, 0xFF]
result = DeviceInformationFactory.get(
self.control, DeviceInformation.Extended, 0x82
)
- self.assertEqual(sorted(result.keys()), [0x82, 0xFF])
+ assert sorted(result.keys()) == [0x82, 0xFF]
result = DeviceInformationFactory.get(
self.control, DeviceInformation.Extended, 0x81
)
- self.assertEqual(
- sorted(result.keys()),
- [0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x80, 0x82, 0xFF],
- )
+ assert sorted(result.keys()) == [
+ 0x00,
+ 0x01,
+ 0x02,
+ 0x03,
+ 0x04,
+ 0x05,
+ 0x06,
+ 0x80,
+ 0x82,
+ 0xFF,
+ ]
def test_basic_commands(self):
"""Test device identification reading"""
- self.assertEqual(str(self.ident), "DeviceIdentity")
- self.assertEqual(str(self.control), "ModbusControl")
+ assert str(self.ident) == "DeviceIdentity"
+ assert str(self.control) == "ModbusControl"
def test_modbus_device_identification_get(self):
"""Test device identification reading"""
- self.assertEqual(self.ident[0x00], "Bashwork")
- self.assertEqual(self.ident[0x01], "PTM")
- self.assertEqual(self.ident[0x02], "1.0")
- self.assertEqual(self.ident[0x03], "http://internets.com")
- self.assertEqual(self.ident[0x04], "pymodbus")
- self.assertEqual(self.ident[0x05], "bashwork")
- self.assertEqual(self.ident[0x06], "unittest")
- self.assertNotEqual(self.ident[0x07], "x")
- self.assertNotEqual(self.ident[0x08], "x")
- self.assertNotEqual(self.ident[0x10], "reserved")
- self.assertEqual(self.ident[0x54], "")
+ assert self.ident[0x00] == "Bashwork"
+ assert self.ident[0x01] == "PTM"
+ assert self.ident[0x02] == "1.0"
+ assert self.ident[0x03] == "http://internets.com"
+ assert self.ident[0x04] == "pymodbus"
+ assert self.ident[0x05] == "bashwork"
+ assert self.ident[0x06] == "pytest"
+ assert self.ident[0x07] != "x"
+ assert self.ident[0x08] != "x"
+ assert self.ident[0x10] != "reserved"
+ assert not self.ident[0x54]
def test_modbus_device_identification_summary(self):
"""Test device identification summary creation"""
summary = sorted(self.ident.summary().values())
expected = sorted(list(self.info.values())[:0x07]) # remove private
- self.assertEqual(summary, expected)
+ assert summary == expected
def test_modbus_device_identification_set(self):
"""Test a device identification writing"""
@@ -175,31 +184,31 @@ def test_modbus_device_identification_set(self):
self.ident[0x10] = "public"
self.ident[0x54] = "testing"
- self.assertNotEqual("y", self.ident[0x07])
- self.assertNotEqual("y", self.ident[0x08])
- self.assertEqual("public", self.ident[0x10])
- self.assertEqual("testing", self.ident[0x54])
+ assert self.ident[0x07] != "y"
+ assert self.ident[0x08] != "y"
+ assert self.ident[0x10] == "public"
+ assert self.ident[0x54] == "testing"
def test_modbus_control_block_ascii_modes(self):
"""Test a server control block ascii mode"""
- self.assertEqual(id(self.control), id(ModbusControlBlock()))
+ assert id(self.control) == id(ModbusControlBlock())
self.control.Mode = "RTU"
- self.assertEqual("RTU", self.control.Mode)
+ assert self.control.Mode == "RTU"
self.control.Mode = "FAKE"
- self.assertNotEqual("FAKE", self.control.Mode)
+ assert self.control.Mode != "FAKE"
def test_modbus_control_block_counters(self):
"""Tests the MCB counters methods"""
- self.assertEqual(0x0, self.control.Counter.BusMessage)
+ assert not self.control.Counter.BusMessage
for _ in range(10):
self.control.Counter.BusMessage += 1
self.control.Counter.SlaveMessage += 1
- self.assertEqual(10, self.control.Counter.BusMessage)
+ assert self.control.Counter.BusMessage == 10
self.control.Counter.BusMessage = 0x00
- self.assertEqual(0, self.control.Counter.BusMessage)
- self.assertEqual(10, self.control.Counter.SlaveMessage)
+ assert not self.control.Counter.BusMessage
+ assert self.control.Counter.SlaveMessage == 10
self.control.Counter.reset()
- self.assertEqual(0, self.control.Counter.SlaveMessage)
+ assert not self.control.Counter.SlaveMessage
def test_modbus_control_block_update(self):
"""Tests the MCB counters update methods"""
@@ -207,96 +216,96 @@ def test_modbus_control_block_update(self):
self.control.Counter.BusMessage += 1
self.control.Counter.SlaveMessage += 1
self.control.Counter.update(values)
- self.assertEqual(6, self.control.Counter.SlaveMessage)
- self.assertEqual(6, self.control.Counter.BusMessage)
+ assert self.control.Counter.SlaveMessage == 6
+ assert self.control.Counter.BusMessage == 6
def test_modbus_control_block_iterator(self):
"""Tests the MCB counters iterator"""
self.control.Counter.reset()
for _, count in self.control:
- self.assertEqual(0, count)
+ assert not count
def test_modbus_counters_handler_iterator(self):
"""Tests the MCB counters iterator"""
self.control.Counter.reset()
for _, count in self.control.Counter:
- self.assertEqual(0, count)
+ assert not count
def test_modbus_control_block_counter_summary(self):
"""Tests retrieving the current counter summary"""
- self.assertEqual(0x00, self.control.Counter.summary())
+ assert not self.control.Counter.summary()
for _ in range(10):
self.control.Counter.BusMessage += 1
self.control.Counter.SlaveMessage += 1
self.control.Counter.SlaveNAK += 1
self.control.Counter.BusCharacterOverrun += 1
- self.assertEqual(0xA9, self.control.Counter.summary())
+ assert self.control.Counter.summary() == 0xA9
self.control.Counter.reset()
- self.assertEqual(0x00, self.control.Counter.summary())
+ assert not self.control.Counter.summary()
def test_modbus_control_block_listen(self):
"""Test the MCB listen flag methods"""
self.control.ListenOnly = False
- self.assertEqual(self.control.ListenOnly, False)
+ assert not self.control.ListenOnly
self.control.ListenOnly = not self.control.ListenOnly
- self.assertEqual(self.control.ListenOnly, True)
+ assert self.control.ListenOnly
def test_modbus_control_block_delimiter(self):
"""Tests the MCB delimiter setting methods"""
self.control.Delimiter = b"\r"
- self.assertEqual(self.control.Delimiter, b"\r")
+ assert self.control.Delimiter == b"\r"
self.control.Delimiter = "="
- self.assertEqual(self.control.Delimiter, b"=")
+ assert self.control.Delimiter == b"="
self.control.Delimiter = 61
- self.assertEqual(self.control.Delimiter, b"=")
+ assert self.control.Delimiter == b"="
def test_modbus_control_block_diagnostic(self):
"""Tests the MCB delimiter setting methods"""
- self.assertEqual([False] * 16, self.control.getDiagnosticRegister())
+ assert self.control.getDiagnosticRegister() == [False] * 16
for i in (1, 3, 4, 6):
self.control.setDiagnostic({i: True})
- self.assertEqual(True, self.control.getDiagnostic(1))
- self.assertEqual(False, self.control.getDiagnostic(2))
+ assert self.control.getDiagnostic(1)
+ assert not self.control.getDiagnostic(2)
actual = [False, True, False, True, True, False, True] + [False] * 9
- self.assertEqual(actual, self.control.getDiagnosticRegister())
+ assert actual == self.control.getDiagnosticRegister()
for i in range(16):
self.control.setDiagnostic({i: False})
def test_modbus_control_block_invalid_diagnostic(self):
"""Tests querying invalid MCB counters methods"""
- self.assertEqual(None, self.control.getDiagnostic(-1))
- self.assertEqual(None, self.control.getDiagnostic(17))
- self.assertEqual(None, self.control.getDiagnostic(None))
- self.assertEqual(None, self.control.getDiagnostic([1, 2, 3]))
+ assert not self.control.getDiagnostic(-1)
+ assert not self.control.getDiagnostic(17)
+ assert not self.control.getDiagnostic(None)
+ assert not self.control.getDiagnostic([1, 2, 3])
def test_clearing_control_events(self):
"""Test adding and clearing modbus events"""
- self.assertEqual(self.control.Events, [])
+ assert self.control.Events == []
event = ModbusEvent()
self.control.addEvent(event)
- self.assertEqual(self.control.Events, [event])
- self.assertEqual(self.control.Counter.Event, 1)
+ assert self.control.Events == [event]
+ assert self.control.Counter.Event == 1
self.control.clearEvents()
- self.assertEqual(self.control.Events, [])
- self.assertEqual(self.control.Counter.Event, 1)
+ assert self.control.Events == []
+ assert self.control.Counter.Event == 1
def test_retrieving_control_events(self):
"""Test adding and removing a host"""
- self.assertEqual(self.control.Events, [])
+ assert self.control.Events == []
event = RemoteReceiveEvent()
self.control.addEvent(event)
- self.assertEqual(self.control.Events, [event])
+ assert self.control.Events == [event]
packet = self.control.getEvents()
- self.assertEqual(packet, b"\x40")
+ assert packet == b"\x40"
def test_modbus_plus_statistics(self):
"""Test device identification reading"""
default = [0x0000] * 55
statistics = ModbusPlusStatistics()
- self.assertEqual(default, statistics.encode())
+ assert default == statistics.encode()
statistics.reset()
- self.assertEqual(default, statistics.encode())
- self.assertEqual(default, self.control.Plus.encode())
+ assert default == statistics.encode()
+ assert default == self.control.Plus.encode()
def test_modbus_plus_statistics_helpers(self):
"""Test modbus plus statistics helper methods"""
@@ -351,5 +360,5 @@ def test_modbus_plus_statistics_helpers(self):
[0, 0, 0, 0, 0, 0, 0, 0],
]
stats_summary = list(statistics.summary())
- self.assertEqual(sorted(summary), sorted(stats_summary))
- self.assertEqual(0x00, sum(sum(value[1]) for value in statistics))
+ assert sorted(summary) == sorted(stats_summary)
+ assert not sum(sum(value[1]) for value in statistics)
diff --git a/test/test_diag_messages.py b/test/test_diag_messages.py
index 526bfe29f..dda7d7a87 100644
--- a/test/test_diag_messages.py
+++ b/test/test_diag_messages.py
@@ -1,5 +1,5 @@
"""Test diag messages."""
-import unittest
+import pytest
from pymodbus.constants import ModbusPlusOperation
from pymodbus.diag_message import (
@@ -45,103 +45,105 @@
from pymodbus.exceptions import NotImplementedException
-class SimpleDataStoreTest(unittest.TestCase):
+class TestDataStore:
"""Unittest for the pymodbus.diag_message module."""
- def setUp(self):
- """Do setup."""
- self.requests = [
- (
- RestartCommunicationsOptionRequest,
- b"\x00\x01\x00\x00",
- b"\x00\x01\xff\x00",
- ),
- (ReturnDiagnosticRegisterRequest, b"\x00\x02\x00\x00", b"\x00\x02\x00\x00"),
- (
- ChangeAsciiInputDelimiterRequest,
- b"\x00\x03\x00\x00",
- b"\x00\x03\x00\x00",
- ),
- (ForceListenOnlyModeRequest, b"\x00\x04\x00\x00", b"\x00\x04"),
- (ReturnQueryDataRequest, b"\x00\x00\x00\x00", b"\x00\x00\x00\x00"),
- (ClearCountersRequest, b"\x00\x0a\x00\x00", b"\x00\x0a\x00\x00"),
- (ReturnBusMessageCountRequest, b"\x00\x0b\x00\x00", b"\x00\x0b\x00\x00"),
- (
- ReturnBusCommunicationErrorCountRequest,
- b"\x00\x0c\x00\x00",
- b"\x00\x0c\x00\x00",
- ),
- (
- ReturnBusExceptionErrorCountRequest,
- b"\x00\x0d\x00\x00",
- b"\x00\x0d\x00\x00",
- ),
- (ReturnSlaveMessageCountRequest, b"\x00\x0e\x00\x00", b"\x00\x0e\x00\x00"),
- (
- ReturnSlaveNoResponseCountRequest,
- b"\x00\x0f\x00\x00",
- b"\x00\x0f\x00\x00",
- ),
- (ReturnSlaveNAKCountRequest, b"\x00\x10\x00\x00", b"\x00\x10\x00\x00"),
- (ReturnSlaveBusyCountRequest, b"\x00\x11\x00\x00", b"\x00\x11\x00\x00"),
- (
- ReturnSlaveBusCharacterOverrunCountRequest,
- b"\x00\x12\x00\x00",
- b"\x00\x12\x00\x00",
- ),
- (ReturnIopOverrunCountRequest, b"\x00\x13\x00\x00", b"\x00\x13\x00\x00"),
- (ClearOverrunCountRequest, b"\x00\x14\x00\x00", b"\x00\x14\x00\x00"),
- (
- GetClearModbusPlusRequest,
- b"\x00\x15\x00\x00",
- b"\x00\x15\x00\x00" + b"\x00\x00" * 55,
- ),
- ]
-
- self.responses = [
- # (DiagnosticStatusResponse, b"\x00\x00\x00\x00"),
- # (DiagnosticStatusSimpleResponse, b"\x00\x00\x00\x00"),
- (ReturnQueryDataResponse, b"\x00\x00\x00\x00"),
- (RestartCommunicationsOptionResponse, b"\x00\x01\x00\x00"),
- (ReturnDiagnosticRegisterResponse, b"\x00\x02\x00\x00"),
- (ChangeAsciiInputDelimiterResponse, b"\x00\x03\x00\x00"),
- (ForceListenOnlyModeResponse, b"\x00\x04"),
- (ReturnQueryDataResponse, b"\x00\x00\x00\x00"),
- (ClearCountersResponse, b"\x00\x0a\x00\x00"),
- (ReturnBusMessageCountResponse, b"\x00\x0b\x00\x00"),
- (ReturnBusCommunicationErrorCountResponse, b"\x00\x0c\x00\x00"),
- (ReturnBusExceptionErrorCountResponse, b"\x00\x0d\x00\x00"),
- (ReturnSlaveMessageCountResponse, b"\x00\x0e\x00\x00"),
- (ReturnSlaveNoResponseCountResponse, b"\x00\x0f\x00\x00"),
- (ReturnSlaveNAKCountResponse, b"\x00\x10\x00\x00"),
- (ReturnSlaveBusyCountResponse, b"\x00\x11\x00\x00"),
- (ReturnSlaveBusCharacterOverrunCountResponse, b"\x00\x12\x00\x00"),
- (ReturnIopOverrunCountResponse, b"\x00\x13\x00\x00"),
- (ClearOverrunCountResponse, b"\x00\x14\x00\x00"),
- (GetClearModbusPlusResponse, b"\x00\x15\x00\x04" + b"\x00\x00" * 55),
- ]
-
- def tearDown(self):
- """Clean up the test environment"""
- del self.requests
- del self.responses
+ requests = [
+ (
+ RestartCommunicationsOptionRequest,
+ b"\x00\x01\x00\x00",
+ b"\x00\x01\xff\x00",
+ ),
+ (ReturnDiagnosticRegisterRequest, b"\x00\x02\x00\x00", b"\x00\x02\x00\x00"),
+ (
+ ChangeAsciiInputDelimiterRequest,
+ b"\x00\x03\x00\x00",
+ b"\x00\x03\x00\x00",
+ ),
+ (ForceListenOnlyModeRequest, b"\x00\x04\x00\x00", b"\x00\x04"),
+ (ReturnQueryDataRequest, b"\x00\x00\x00\x00", b"\x00\x00\x00\x00"),
+ (ClearCountersRequest, b"\x00\x0a\x00\x00", b"\x00\x0a\x00\x00"),
+ (ReturnBusMessageCountRequest, b"\x00\x0b\x00\x00", b"\x00\x0b\x00\x00"),
+ (
+ ReturnBusCommunicationErrorCountRequest,
+ b"\x00\x0c\x00\x00",
+ b"\x00\x0c\x00\x00",
+ ),
+ (
+ ReturnBusExceptionErrorCountRequest,
+ b"\x00\x0d\x00\x00",
+ b"\x00\x0d\x00\x00",
+ ),
+ (ReturnSlaveMessageCountRequest, b"\x00\x0e\x00\x00", b"\x00\x0e\x00\x00"),
+ (
+ ReturnSlaveNoResponseCountRequest,
+ b"\x00\x0f\x00\x00",
+ b"\x00\x0f\x00\x00",
+ ),
+ (ReturnSlaveNAKCountRequest, b"\x00\x10\x00\x00", b"\x00\x10\x00\x00"),
+ (ReturnSlaveBusyCountRequest, b"\x00\x11\x00\x00", b"\x00\x11\x00\x00"),
+ (
+ ReturnSlaveBusCharacterOverrunCountRequest,
+ b"\x00\x12\x00\x00",
+ b"\x00\x12\x00\x00",
+ ),
+ (ReturnIopOverrunCountRequest, b"\x00\x13\x00\x00", b"\x00\x13\x00\x00"),
+ (ClearOverrunCountRequest, b"\x00\x14\x00\x00", b"\x00\x14\x00\x00"),
+ (
+ GetClearModbusPlusRequest,
+ b"\x00\x15\x00\x00",
+ b"\x00\x15\x00\x00" + b"\x00\x00" * 55,
+ ),
+ ]
+
+ responses = [
+ # (DiagnosticStatusResponse, b"\x00\x00\x00\x00"),
+ # (DiagnosticStatusSimpleResponse, b"\x00\x00\x00\x00"),
+ (ReturnQueryDataResponse, b"\x00\x00\x00\x00"),
+ (RestartCommunicationsOptionResponse, b"\x00\x01\x00\x00"),
+ (ReturnDiagnosticRegisterResponse, b"\x00\x02\x00\x00"),
+ (ChangeAsciiInputDelimiterResponse, b"\x00\x03\x00\x00"),
+ (ForceListenOnlyModeResponse, b"\x00\x04"),
+ (ReturnQueryDataResponse, b"\x00\x00\x00\x00"),
+ (ClearCountersResponse, b"\x00\x0a\x00\x00"),
+ (ReturnBusMessageCountResponse, b"\x00\x0b\x00\x00"),
+ (ReturnBusCommunicationErrorCountResponse, b"\x00\x0c\x00\x00"),
+ (ReturnBusExceptionErrorCountResponse, b"\x00\x0d\x00\x00"),
+ (ReturnSlaveMessageCountResponse, b"\x00\x0e\x00\x00"),
+ (ReturnSlaveNoResponseCountResponse, b"\x00\x0f\x00\x00"),
+ (ReturnSlaveNAKCountResponse, b"\x00\x10\x00\x00"),
+ (ReturnSlaveBusyCountResponse, b"\x00\x11\x00\x00"),
+ (ReturnSlaveBusCharacterOverrunCountResponse, b"\x00\x12\x00\x00"),
+ (ReturnIopOverrunCountResponse, b"\x00\x13\x00\x00"),
+ (ClearOverrunCountResponse, b"\x00\x14\x00\x00"),
+ (GetClearModbusPlusResponse, b"\x00\x15\x00\x04" + b"\x00\x00" * 55),
+ ]
+
+ def test_diagnostic_encode_decode(self):
+ """Testing diagnostic request/response can be decoded and encoded."""
+ for msg in (DiagnosticStatusRequest, DiagnosticStatusResponse):
+ msg_obj = msg()
+ data = b"\x00\x01\x02\x03"
+ msg_obj.decode(data)
+ result = msg_obj.encode()
+ assert data == result
def test_diagnostic_requests_decode(self):
"""Testing diagnostic request messages encoding"""
for msg, enc, _ in self.requests:
handle = DiagnosticStatusRequest()
handle.decode(enc)
- self.assertEqual(handle.sub_function_code, msg.sub_function_code)
+ assert handle.sub_function_code == msg.sub_function_code
+ encoded = handle.encode()
+ assert enc == encoded
def test_diagnostic_simple_requests(self):
"""Testing diagnostic request messages encoding"""
request = DiagnosticStatusSimpleRequest(b"\x12\x34")
request.sub_function_code = 0x1234
- self.assertRaises(
- NotImplementedException,
- lambda: request.execute(), # pylint: disable=unnecessary-lambda
- )
- self.assertEqual(request.encode(), b"\x12\x34\x12\x34")
+ with pytest.raises(NotImplementedException):
+ request.execute()
+ assert request.encode() == b"\x12\x34\x12\x34"
DiagnosticStatusSimpleResponse(None)
def test_diagnostic_response_decode(self):
@@ -149,52 +151,52 @@ def test_diagnostic_response_decode(self):
for msg, enc, _ in self.requests:
handle = DiagnosticStatusResponse()
handle.decode(enc)
- self.assertEqual(handle.sub_function_code, msg.sub_function_code)
+ assert handle.sub_function_code == msg.sub_function_code
def test_diagnostic_requests_encode(self):
"""Testing diagnostic request messages encoding"""
for msg, enc, _ in self.requests:
- self.assertEqual(msg().encode(), enc)
+ assert msg().encode() == enc
def test_diagnostic_execute(self):
"""Testing diagnostic message execution"""
for message, encoded, executed in self.requests:
encoded = message().execute().encode()
- self.assertEqual(encoded, executed)
+ assert encoded == executed
def test_return_query_data_request(self):
"""Testing diagnostic message execution"""
message = ReturnQueryDataRequest([0x0000] * 2)
- self.assertEqual(message.encode(), b"\x00\x00\x00\x00\x00\x00")
+ assert message.encode() == b"\x00\x00\x00\x00\x00\x00"
message = ReturnQueryDataRequest(0x0000)
- self.assertEqual(message.encode(), b"\x00\x00\x00\x00")
+ assert message.encode() == b"\x00\x00\x00\x00"
def test_return_query_data_response(self):
"""Testing diagnostic message execution"""
message = ReturnQueryDataResponse([0x0000] * 2)
- self.assertEqual(message.encode(), b"\x00\x00\x00\x00\x00\x00")
+ assert message.encode() == b"\x00\x00\x00\x00\x00\x00"
message = ReturnQueryDataResponse(0x0000)
- self.assertEqual(message.encode(), b"\x00\x00\x00\x00")
+ assert message.encode() == b"\x00\x00\x00\x00"
def test_restart_cmmunications_option(self):
"""Testing diagnostic message execution"""
request = RestartCommunicationsOptionRequest(True)
- self.assertEqual(request.encode(), b"\x00\x01\xff\x00")
+ assert request.encode() == b"\x00\x01\xff\x00"
request = RestartCommunicationsOptionRequest(False)
- self.assertEqual(request.encode(), b"\x00\x01\x00\x00")
+ assert request.encode() == b"\x00\x01\x00\x00"
response = RestartCommunicationsOptionResponse(True)
- self.assertEqual(response.encode(), b"\x00\x01\xff\x00")
+ assert response.encode() == b"\x00\x01\xff\x00"
response = RestartCommunicationsOptionResponse(False)
- self.assertEqual(response.encode(), b"\x00\x01\x00\x00")
+ assert response.encode() == b"\x00\x01\x00\x00"
def test_get_clear_modbus_plus_request_execute(self):
"""Testing diagnostic message execution"""
request = GetClearModbusPlusRequest(data=ModbusPlusOperation.ClearStatistics)
response = request.execute()
- self.assertEqual(response.message, ModbusPlusOperation.ClearStatistics)
+ assert response.message == ModbusPlusOperation.ClearStatistics
request = GetClearModbusPlusRequest(data=ModbusPlusOperation.GetStatistics)
response = request.execute()
resp = [ModbusPlusOperation.GetStatistics]
- self.assertEqual(response.message, resp + [0x00] * 55)
+ assert response.message == resp + [0x00] * 55
diff --git a/test/test_events.py b/test/test_events.py
index 7eac3de03..bfe68bda9 100644
--- a/test/test_events.py
+++ b/test/test_events.py
@@ -1,5 +1,5 @@
"""Test events."""
-import unittest
+import pytest
from pymodbus.events import (
CommunicationRestartEvent,
@@ -11,41 +11,37 @@
from pymodbus.exceptions import NotImplementedException, ParameterException
-class ModbusEventsTest(unittest.TestCase):
+class TestEvents:
"""Unittest for the pymodbus.device module."""
- def setUp(self):
- """Set up the test environment"""
-
- def tearDown(self):
- """Clean up the test environment"""
-
def test_modbus_event_base_class(self):
"""Test modbus event base class."""
event = ModbusEvent()
- self.assertRaises(NotImplementedException, event.encode)
- self.assertRaises(NotImplementedException, lambda: event.decode(None))
+ with pytest.raises(NotImplementedException):
+ event.encode()
+ with pytest.raises(NotImplementedException):
+ event.decode(None)
def test_remote_receive_event(self):
"""Test remove receive event."""
event = RemoteReceiveEvent()
event.decode(b"\x70")
- self.assertTrue(event.overrun)
- self.assertTrue(event.listen)
- self.assertTrue(event.broadcast)
+ assert event.overrun
+ assert event.listen
+ assert event.broadcast
def test_remote_sent_event(self):
"""Test remote sent event."""
event = RemoteSendEvent()
result = event.encode()
- self.assertEqual(result, b"\x40")
+ assert result == b"\x40"
event.decode(b"\x7f")
- self.assertTrue(event.read)
- self.assertTrue(event.slave_abort)
- self.assertTrue(event.slave_busy)
- self.assertTrue(event.slave_nak)
- self.assertTrue(event.write_timeout)
- self.assertTrue(event.listen)
+ assert event.read
+ assert event.slave_abort
+ assert event.slave_busy
+ assert event.slave_nak
+ assert event.write_timeout
+ assert event.listen
def test_remote_sent_event_encode(self):
"""Test remote sent event encode."""
@@ -59,22 +55,24 @@ def test_remote_sent_event_encode(self):
}
event = RemoteSendEvent(**arguments)
result = event.encode()
- self.assertEqual(result, b"\x7f")
+ assert result == b"\x7f"
def test_entered_listen_mode_event(self):
"""Test entered listen mode event."""
event = EnteredListenModeEvent()
result = event.encode()
- self.assertEqual(result, b"\x04")
+ assert result == b"\x04"
event.decode(b"\x04")
- self.assertEqual(event.value, 0x04)
- self.assertRaises(ParameterException, lambda: event.decode(b"\x00"))
+ assert event.value == 0x04
+ with pytest.raises(ParameterException):
+ event.decode(b"\x00")
def test_communication_restart_event(self):
"""Test communication restart event."""
event = CommunicationRestartEvent()
result = event.encode()
- self.assertEqual(result, b"\x00")
+ assert result == b"\x00"
event.decode(b"\x00")
- self.assertEqual(event.value, 0x00)
- self.assertRaises(ParameterException, lambda: event.decode(b"\x04"))
+ assert not event.value
+ with pytest.raises(ParameterException):
+ event.decode(b"\x04")
diff --git a/test/test_example_client_server.py b/test/test_example_client_server.py
new file mode 100755
index 000000000..e80cd9436
--- /dev/null
+++ b/test/test_example_client_server.py
@@ -0,0 +1,133 @@
+"""Test example server/client sync/async
+
+This is a thorough test of the generic examples
+(in principle examples that are used in other
+examples, like run a server).
+"""
+import asyncio
+import logging
+from threading import Thread
+from time import sleep
+
+import pytest
+import pytest_asyncio
+
+from examples.client_async import run_async_client, setup_async_client
+from examples.client_calls import run_sync_calls
+from examples.client_sync import run_sync_client, setup_sync_client
+from examples.client_test import run_async_calls as run_async_simple_calls
+from examples.helper import get_commandline
+from examples.server_async import run_async_server, setup_server
+from examples.server_sync import run_sync_server
+from pymodbus import pymodbus_apply_logging_config
+from pymodbus.server import ServerAsyncStop, ServerStop
+
+
+_logger = logging.getLogger()
+_logger.setLevel("DEBUG")
+pymodbus_apply_logging_config("DEBUG")
+TEST_COMMS_FRAMER = [
+ ("tcp", "socket", 5020),
+ ("tcp", "rtu", 5020),
+ ("tls", "tls", 5020),
+ ("udp", "socket", 5020),
+ ("udp", "rtu", 5020),
+ ("serial", "rtu", "socket://127.0.0.1:5020"),
+ # awaiting fix: ("serial", "ascii", "socket://127.0.0.1:5020"),
+ # awaiting fix: ("serial", "binary", "socket://127.0.0.1:5020"),
+]
+
+
+@pytest_asyncio.fixture(name="mock_run_server")
+async def _helper_server(
+ test_comm,
+ test_framer,
+ test_port,
+):
+ """Run server."""
+ cmdline = [
+ "--comm",
+ test_comm,
+ "--port",
+ str(test_port),
+ "--framer",
+ test_framer,
+ "--baudrate",
+ "9600",
+ "--log",
+ "debug",
+ ]
+ run_args = setup_server(cmdline=cmdline)
+ task = asyncio.create_task(run_async_server(run_args))
+ await asyncio.sleep(0.1)
+ yield
+ await ServerAsyncStop()
+ task.cancel()
+ await task
+
+
+def test_get_commandline():
+ """Test helper get_commandline()"""
+ args = get_commandline(cmdline=["--log", "info"])
+ assert args.log == "info"
+ assert args.host == "127.0.0.1"
+
+
+@pytest.mark.xdist_group(name="server_serialize")
+@pytest.mark.parametrize(("test_comm", "test_framer", "test_port"), TEST_COMMS_FRAMER)
+async def test_exp_async_server_client(
+ test_comm,
+ test_framer,
+ test_port,
+ mock_run_server,
+):
+ """Run async client and server."""
+ assert not mock_run_server
+ cmdline = [
+ "--comm",
+ test_comm,
+ "--host",
+ "127.0.0.1",
+ "--framer",
+ test_framer,
+ "--port",
+ str(test_port),
+ "--baudrate",
+ "9600",
+ "--log",
+ "debug",
+ ]
+ test_client = setup_async_client(cmdline=cmdline)
+ await run_async_client(test_client, modbus_calls=run_async_simple_calls)
+
+
+@pytest.mark.xdist_group(name="server_serialize")
+@pytest.mark.parametrize(
+ ("test_comm", "test_framer", "test_port"), [TEST_COMMS_FRAMER[0]]
+)
+def test_exp_sync_server_client(
+ test_comm,
+ test_framer,
+ test_port,
+):
+ """Run sync client and server."""
+ cmdline = [
+ "--comm",
+ test_comm,
+ "--port",
+ str(test_port),
+ "--baudrate",
+ "9600",
+ "--log",
+ "debug",
+ "--framer",
+ test_framer,
+ ]
+ run_args = setup_server(cmdline=cmdline)
+ thread = Thread(target=run_sync_server, args=(run_args,))
+ thread.daemon = True
+ thread.start()
+ sleep(1)
+ test_client = setup_sync_client(cmdline=cmdline)
+ run_sync_client(test_client, modbus_calls=run_sync_calls)
+ ServerStop()
diff --git a/test/test_examples.py b/test/test_examples.py
index 1fda68526..7803510f8 100755
--- a/test/test_examples.py
+++ b/test/test_examples.py
@@ -1,157 +1,182 @@
-"""Test client async."""
+"""Test examples to ensure they run
+
+the following are excluded:
+ client_async.py
+ client_calls.py
+ client_sync.py
+ helper.py
+ server_async.py
+ server_sync.py
+
+they represent generic examples and
+are tested in
+ test_example_client_server.py
+a lot more thoroughly.
+"""
import asyncio
import logging
-from threading import Thread
-from time import sleep
import pytest
import pytest_asyncio
+from examples.build_bcd_payload import BcdPayloadBuilder, BcdPayloadDecoder
from examples.client_async import run_async_client, setup_async_client
-from examples.client_calls import run_async_calls, run_sync_calls
+from examples.client_custom_msg import run_custom_client
from examples.client_payload import run_payload_calls
-from examples.client_sync import run_sync_client, setup_sync_client
-from examples.helper import Commandline
-
-# from examples.modbus_forwarder import run_forwarder
+from examples.client_test import run_async_calls as run_client_test
+from examples.message_generator import generate_messages
+from examples.message_parser import parse_messages
from examples.server_async import run_async_server, setup_server
+from examples.server_callback import run_callback_server
from examples.server_payload import setup_payload_server
-from examples.server_sync import run_sync_server
+from examples.server_simulator import run_server_simulator, setup_simulator
+from examples.server_updating import run_updating_server, setup_updating_server
from pymodbus import pymodbus_apply_logging_config
-from pymodbus.server import ServerAsyncStop, ServerStop
-from pymodbus.transaction import (
- ModbusAsciiFramer,
- ModbusBinaryFramer,
- ModbusRtuFramer,
- ModbusSocketFramer,
- ModbusTlsFramer,
-)
+from pymodbus.server import ServerAsyncStop
+
+
+# from examples.serial_forwarder import run_forwarder
_logger = logging.getLogger()
_logger.setLevel("DEBUG")
-TEST_COMMS_FRAMER = [
- ("tcp", ModbusSocketFramer, 5020),
- ("tcp", ModbusRtuFramer, 5021),
- ("tls", ModbusTlsFramer, 5020),
- ("udp", ModbusSocketFramer, 5020),
- ("udp", ModbusRtuFramer, 5021),
- ("serial", ModbusRtuFramer, 5020),
- ("serial", ModbusAsciiFramer, 5021),
- ("serial", ModbusBinaryFramer, 5022),
+pymodbus_apply_logging_config("DEBUG")
+
+
+CMDARGS = [
+ "--comm",
+ "tcp",
+ "--port",
+ "5020",
+ "--baudrate",
+ "9600",
+ "--log",
+ "debug",
+ "--framer",
+ "socket",
]
@pytest_asyncio.fixture(name="mock_run_server")
-async def _helper_server(
- test_comm,
- test_framer,
- test_port_offset,
- test_port,
-):
+async def _helper_server():
"""Run server."""
- if pytest.IS_WINDOWS and test_comm == "serial":
- yield
- return
- args = Commandline.copy()
- args.comm = test_comm
- args.framer = test_framer
- args.port = test_port + test_port_offset
- if test_comm == "serial":
- args.port = f"socket://127.0.0.1:{args.port}"
- run_args = setup_server(args)
- asyncio.create_task(run_async_server(run_args))
+ run_args = setup_server(cmdline=CMDARGS)
+ task = asyncio.create_task(run_async_server(run_args))
await asyncio.sleep(0.1)
yield
await ServerAsyncStop()
+ await asyncio.sleep(0.1)
+ task.cancel()
+ await task
+ await asyncio.sleep(0.1)
-async def run_client(test_comm, test_type, args=Commandline.copy()):
- """Help run async client."""
-
- args.comm = test_comm
- if test_comm == "serial":
- args.port = f"socket://127.0.0.1:{args.port}"
- test_client = setup_async_client(args=args)
- if not test_type:
- await run_async_client(test_client)
- else:
- await run_async_client(test_client, modbus_calls=test_type)
+@pytest.mark.xdist_group(name="server_serialize")
+async def test_exp_server_client_payload():
+ """Test server/client with payload."""
+ run_args = setup_payload_server(cmdline=CMDARGS)
+ task = asyncio.create_task(run_async_server(run_args))
await asyncio.sleep(0.1)
+ testclient = setup_async_client(cmdline=CMDARGS)
+ await run_async_client(testclient, modbus_calls=run_payload_calls)
+ await asyncio.sleep(0.1)
+ await ServerAsyncStop()
+ await asyncio.sleep(0.1)
+ task.cancel()
+ await task
-@pytest.mark.parametrize("test_port_offset", [10])
-@pytest.mark.parametrize("test_comm, test_framer, test_port", TEST_COMMS_FRAMER)
-async def test_exp_async_server_client(
- test_comm,
- test_framer,
- test_port_offset,
- test_port,
- mock_run_server,
-):
- """Run async client and server."""
- # JAN WAITING
- if pytest.IS_WINDOWS and test_comm == "serial":
- return
- if test_comm in {"tcp", "tls"}:
- return
+@pytest.mark.xdist_group(name="server_serialize")
+async def test_exp_client_test(mock_run_server):
+ """Test client used for fast testing."""
assert not mock_run_server
- args = Commandline.copy()
- args.framer = test_framer
- args.comm = test_comm
- args.port = test_port + test_port_offset
- await run_client(test_comm, None, args=args)
+ testclient = setup_async_client(cmdline=CMDARGS)
+ await run_async_client(testclient, modbus_calls=run_client_test)
-@pytest.mark.parametrize("test_port_offset", [20])
-@pytest.mark.parametrize("test_comm, test_framer, test_port", [TEST_COMMS_FRAMER[0]])
-def test_exp_sync_server_client(
- test_comm,
- test_framer,
- test_port_offset,
- test_port,
-):
- """Run sync client and server."""
- args = Commandline.copy()
- args.comm = test_comm
- args.port = test_port + test_port_offset
- args.framer = test_framer
- run_args = setup_server(args)
- thread = Thread(target=run_sync_server, args=(run_args,))
- thread.daemon = True
- thread.start()
- sleep(1)
- test_client = setup_sync_client(args=args)
- run_sync_client(test_client, modbus_calls=run_sync_calls)
- ServerStop()
-
-
-# JAN
-@pytest.mark.parametrize("test_port_offset", [30])
-@pytest.mark.parametrize("test_comm, test_framer, test_port", TEST_COMMS_FRAMER)
-async def xtest_exp_client_calls(
- test_comm,
- test_framer,
- test_port_offset,
- test_port,
- mock_run_server,
-):
- """Test client-server async with different framers and calls."""
+
+@pytest.mark.parametrize("framer", ["socket", "rtu"])
+async def test_exp_message_generator(framer):
+ """Test all message generator."""
+ generate_messages(cmdline=["--framer", framer])
+
+
+@pytest.mark.xdist_group(name="server_serialize")
+async def test_exp_server_simulator():
+ """Test server simulator."""
+ cmdargs = ["--log", "debug", "--port", "5020"]
+ run_args = setup_simulator(cmdline=cmdargs)
+ task = asyncio.create_task(run_server_simulator(run_args))
+ await asyncio.sleep(0.1)
+ testclient = setup_async_client(cmdline=CMDARGS)
+ await run_async_client(testclient, modbus_calls=run_client_test)
+ await asyncio.sleep(0.1)
+ await ServerAsyncStop()
+ await asyncio.sleep(0.1)
+ task.cancel()
+ await task
+
+
+@pytest.mark.xdist_group(name="server_serialize")
+async def test_exp_updating_server():
+ """Test server simulator."""
+ run_args = setup_updating_server(cmdline=CMDARGS)
+ task = asyncio.create_task(run_updating_server(run_args))
+ await asyncio.sleep(0.1)
+ testclient = setup_async_client(cmdline=CMDARGS)
+ await run_async_client(testclient, modbus_calls=run_client_test)
+ await asyncio.sleep(0.1)
+ await ServerAsyncStop()
+ await asyncio.sleep(0.1)
+ task.cancel()
+ await task
+
+
+def test_exp_build_bcd_payload():
+ """Test build bcd payload."""
+ builder = BcdPayloadBuilder()
+ decoder = BcdPayloadDecoder(builder)
+ assert str(decoder)
+
+
+def test_exp_message_parser():
+ """Test message parser."""
+ parse_messages(["--framer", "socket", "-m", "000100000006010100200001"])
+ parse_messages(["--framer", "socket", "-m", "00010000000401010101"])
+
+
+@pytest.mark.xdist_group(name="server_serialize")
+async def test_exp_server_callback():
+ """Test server/client with payload."""
+ task = asyncio.create_task(run_callback_server(cmdline=CMDARGS))
+ await asyncio.sleep(0.1)
+ testclient = setup_async_client(cmdline=CMDARGS)
+ await run_async_client(testclient, modbus_calls=run_client_test)
+ await asyncio.sleep(0.1)
+ await ServerAsyncStop()
+ await asyncio.sleep(0.1)
+ task.cancel()
+ await task
+
+
+@pytest.mark.xdist_group(name="server_serialize")
+async def test_exp_client_custom_msg(mock_run_server):
+ """Test client with custom message."""
assert not mock_run_server
- if test_comm == "serial" and test_framer in (ModbusAsciiFramer, ModbusBinaryFramer):
- return
- if pytest.IS_WINDOWS and test_comm == "serial":
- return
- args = Commandline.copy()
- args.framer = test_framer
- args.comm = test_comm
- args.port = test_port + test_port_offset
- await run_client(test_comm, run_async_calls, args=args)
+
+ run_custom_client()
-@pytest.mark.parametrize("test_port_offset", [40])
-@pytest.mark.parametrize("test_comm, test_framer, test_port", [TEST_COMMS_FRAMER[0]])
-async def test_exp_forwarder(
+# to be updated:
+# modbus_forwarder.py
+#
+# to be converted:
+# v2.5.3
+
+
+# @pytest.mark.parametrize("test_port_offset", [40])
+# @pytest.mark.parametrize("test_comm, test_framer, test_port", [TEST_COMMS_FRAMER[0]])
+async def xtest_exp_forwarder(
test_comm,
test_framer,
test_port_offset,
@@ -163,8 +188,6 @@ async def test_exp_forwarder(
if pytest.IS_WINDOWS:
return
print(test_comm, test_framer, test_port_offset, test_port)
- pymodbus_apply_logging_config()
- # cmd_args = Commandline.copy()
# cmd_args.comm = test_comm
# cmd_args.framer = test_framer
# cmd_args.port = test_port + test_port_offset + 1
@@ -206,31 +229,3 @@ async def test_exp_forwarder(
# await ServerAsyncStop()
# await asyncio.sleep(0.1)
# task.cancel()
-
-
-@pytest.mark.parametrize("test_port_offset", [50])
-@pytest.mark.parametrize("test_comm, test_framer, test_port", [TEST_COMMS_FRAMER[0]])
-async def test_exp_payload(
- test_comm,
- test_framer,
- test_port_offset,
- test_port,
-):
- """Test server/client with payload."""
- pymodbus_apply_logging_config()
- args = Commandline.copy()
- args.port = test_port + test_port_offset
- args.comm = test_comm
- args.framer = test_framer
- run_args = setup_payload_server(args)
- task = asyncio.create_task(run_async_server(run_args))
- await asyncio.sleep(0.1)
- testclient = setup_async_client(args)
- await run_async_client(testclient, modbus_calls=run_payload_calls)
- await asyncio.sleep(0.1)
- await ServerAsyncStop()
- try:
- await asyncio.sleep(0.1)
- except asyncio.CancelledError:
- pass
- task.cancel()
diff --git a/test/test_exceptions.py b/test/test_exceptions.py
index fe9f8e327..369ad40fd 100644
--- a/test/test_exceptions.py
+++ b/test/test_exceptions.py
@@ -1,5 +1,5 @@
"""Test exceptions."""
-import unittest
+import pytest
from pymodbus.exceptions import (
ConnectionException,
@@ -10,28 +10,19 @@
)
-class SimpleExceptionsTest(unittest.TestCase):
+class TestExceptions: # pylint: disable=too-few-public-methods
"""Unittest for the pymodbus.exceptions module."""
- def setUp(self):
- """Initialize the test environment"""
- self.exceptions = [
- ModbusException("bad base"),
- ModbusIOException("bad register"),
- ParameterException("bad parameter"),
- NotImplementedException("bad function"),
- ConnectionException("bad connection"),
- ]
-
- def tearDown(self):
- """Clean up the test environment"""
+ exceptions = [
+ ModbusException("bad base"),
+ ModbusIOException("bad register"),
+ ParameterException("bad parameter"),
+ NotImplementedException("bad function"),
+ ConnectionException("bad connection"),
+ ]
def test_exceptions(self):
"""Test all module exceptions"""
for exc in self.exceptions:
- try:
+ with pytest.raises(ModbusException, match="Modbus Error:"):
raise exc
- except ModbusException as exc:
- self.assertTrue("Modbus Error:" in str(exc))
- return
- self.fail("Excepted a ModbusExceptions")
diff --git a/test/test_factory.py b/test/test_factory.py
index dbddaae0d..ea2d4902c 100644
--- a/test/test_factory.py
+++ b/test/test_factory.py
@@ -1,5 +1,5 @@
"""Test factory."""
-import unittest
+import pytest
from pymodbus.exceptions import MessageRegisterException, ModbusException
from pymodbus.factory import ClientDecoder, ServerDecoder
@@ -11,154 +11,143 @@ def _raise_exception(_):
raise ModbusException("something")
-class SimpleFactoryTest(unittest.TestCase):
+class TestFactory:
"""Unittest for the pymod.exceptions module."""
- def setUp(self):
- """Initialize the test environment"""
+ client = None
+ server = None
+ request = (
+ (0x01, b"\x01\x00\x01\x00\x01"), # read coils
+ (0x02, b"\x02\x00\x01\x00\x01"), # read discrete inputs
+ (0x03, b"\x03\x00\x01\x00\x01"), # read holding registers
+ (0x04, b"\x04\x00\x01\x00\x01"), # read input registers
+ (0x05, b"\x05\x00\x01\x00\x01"), # write single coil
+ (0x06, b"\x06\x00\x01\x00\x01"), # write single register
+ (0x07, b"\x07"), # read exception status
+ (0x08, b"\x08\x00\x00\x00\x00"), # read diagnostic
+ (0x0B, b"\x0b"), # get comm event counters
+ (0x0C, b"\x0c"), # get comm event log
+ (0x0F, b"\x0f\x00\x01\x00\x08\x01\x00\xff"), # write multiple coils
+ (0x10, b"\x10\x00\x01\x00\x02\x04\0xff\xff"), # write multiple registers
+ (0x11, b"\x11"), # report slave id
+ (
+ 0x14,
+ b"\x14\x0e\x06\x00\x04\x00\x01\x00\x02\x06\x00\x03\x00\x09\x00\x02",
+ ), # read file record
+ (
+ 0x15,
+ b"\x15\x0d\x06\x00\x04\x00\x07\x00\x03\x06\xaf\x04\xbe\x10\x0d",
+ ), # write file record
+ (0x16, b"\x16\x00\x01\x00\xff\xff\x00"), # mask write register
+ (
+ 0x17,
+ b"\x17\x00\x01\x00\x01\x00\x01\x00\x01\x02\x12\x34",
+ ), # r/w multiple regs
+ (0x18, b"\x18\x00\x01"), # read fifo queue
+ (0x2B, b"\x2b\x0e\x01\x00"), # read device identification
+ )
+
+ response = (
+ (0x01, b"\x01\x01\x01"), # read coils
+ (0x02, b"\x02\x01\x01"), # read discrete inputs
+ (0x03, b"\x03\x02\x01\x01"), # read holding registers
+ (0x04, b"\x04\x02\x01\x01"), # read input registers
+ (0x05, b"\x05\x00\x01\x00\x01"), # write single coil
+ (0x06, b"\x06\x00\x01\x00\x01"), # write single register
+ (0x07, b"\x07\x00"), # read exception status
+ (0x08, b"\x08\x00\x00\x00\x00"), # read diagnostic
+ (0x0B, b"\x0b\x00\x00\x00\x00"), # get comm event counters
+ (0x0C, b"\x0c\x08\x00\x00\x01\x08\x01\x21\x20\x00"), # get comm event log
+ (0x0F, b"\x0f\x00\x01\x00\x08"), # write multiple coils
+ (0x10, b"\x10\x00\x01\x00\x02"), # write multiple registers
+ (0x11, b"\x11\x03\x05\x01\x54"), # report slave id (device specific)
+ (
+ 0x14,
+ b"\x14\x0c\x05\x06\x0d\xfe\x00\x20\x05\x06\x33\xcd\x00\x40",
+ ), # read file record
+ (
+ 0x15,
+ b"\x15\x0d\x06\x00\x04\x00\x07\x00\x03\x06\xaf\x04\xbe\x10\x0d",
+ ), # write file record
+ (0x16, b"\x16\x00\x01\x00\xff\xff\x00"), # mask write register
+ (0x17, b"\x17\x02\x12\x34"), # read/write multiple registers
+ (0x18, b"\x18\x00\x01\x00\x01\x00\x00"), # read fifo queue
+ (
+ 0x2B,
+ b"\x2b\x0e\x01\x01\x00\x00\x01\x00\x01\x77",
+ ), # read device identification
+ )
+
+ exception = (
+ (0x81, b"\x81\x01\xd0\x50"), # illegal function exception
+ (0x82, b"\x82\x02\x90\xa1"), # illegal data address exception
+ (0x83, b"\x83\x03\x50\xf1"), # illegal data value exception
+ (0x84, b"\x84\x04\x13\x03"), # skave device failure exception
+ (0x85, b"\x85\x05\xd3\x53"), # acknowledge exception
+ (0x86, b"\x86\x06\x93\xa2"), # slave device busy exception
+ (0x87, b"\x87\x08\x53\xf2"), # memory parity exception
+ (0x88, b"\x88\x0a\x16\x06"), # gateway path unavailable exception
+ (0x89, b"\x89\x0b\xd6\x56"), # gateway target failed exception
+ )
+
+ bad = (
+ (0x80, b"\x80\x00\x00\x00"), # Unknown Function
+ (0x81, b"\x81\x00\x00\x00"), # error message
+ )
+
+ @pytest.fixture(autouse=True)
+ def _setup(self):
+ """Do common setup function."""
self.client = ClientDecoder()
self.server = ServerDecoder()
- self.request = (
- (0x01, b"\x01\x00\x01\x00\x01"), # read coils
- (0x02, b"\x02\x00\x01\x00\x01"), # read discrete inputs
- (0x03, b"\x03\x00\x01\x00\x01"), # read holding registers
- (0x04, b"\x04\x00\x01\x00\x01"), # read input registers
- (0x05, b"\x05\x00\x01\x00\x01"), # write single coil
- (0x06, b"\x06\x00\x01\x00\x01"), # write single register
- (0x07, b"\x07"), # read exception status
- (0x08, b"\x08\x00\x00\x00\x00"), # read diagnostic
- (0x0B, b"\x0b"), # get comm event counters
- (0x0C, b"\x0c"), # get comm event log
- (0x0F, b"\x0f\x00\x01\x00\x08\x01\x00\xff"), # write multiple coils
- (0x10, b"\x10\x00\x01\x00\x02\x04\0xff\xff"), # write multiple registers
- (0x11, b"\x11"), # report slave id
- (
- 0x14,
- b"\x14\x0e\x06\x00\x04\x00\x01\x00\x02\x06\x00\x03\x00\x09\x00\x02",
- ), # read file record
- (
- 0x15,
- b"\x15\x0d\x06\x00\x04\x00\x07\x00\x03\x06\xaf\x04\xbe\x10\x0d",
- ), # write file record
- (0x16, b"\x16\x00\x01\x00\xff\xff\x00"), # mask write register
- (
- 0x17,
- b"\x17\x00\x01\x00\x01\x00\x01\x00\x01\x02\x12\x34",
- ), # r/w multiple regs
- (0x18, b"\x18\x00\x01"), # read fifo queue
- (0x2B, b"\x2b\x0e\x01\x00"), # read device identification
- )
-
- self.response = (
- (0x01, b"\x01\x01\x01"), # read coils
- (0x02, b"\x02\x01\x01"), # read discrete inputs
- (0x03, b"\x03\x02\x01\x01"), # read holding registers
- (0x04, b"\x04\x02\x01\x01"), # read input registers
- (0x05, b"\x05\x00\x01\x00\x01"), # write single coil
- (0x06, b"\x06\x00\x01\x00\x01"), # write single register
- (0x07, b"\x07\x00"), # read exception status
- (0x08, b"\x08\x00\x00\x00\x00"), # read diagnostic
- (0x0B, b"\x0b\x00\x00\x00\x00"), # get comm event counters
- (0x0C, b"\x0c\x08\x00\x00\x01\x08\x01\x21\x20\x00"), # get comm event log
- (0x0F, b"\x0f\x00\x01\x00\x08"), # write multiple coils
- (0x10, b"\x10\x00\x01\x00\x02"), # write multiple registers
- (0x11, b"\x11\x03\x05\x01\x54"), # report slave id (device specific)
- (
- 0x14,
- b"\x14\x0c\x05\x06\x0d\xfe\x00\x20\x05\x06\x33\xcd\x00\x40",
- ), # read file record
- (
- 0x15,
- b"\x15\x0d\x06\x00\x04\x00\x07\x00\x03\x06\xaf\x04\xbe\x10\x0d",
- ), # write file record
- (0x16, b"\x16\x00\x01\x00\xff\xff\x00"), # mask write register
- (0x17, b"\x17\x02\x12\x34"), # read/write multiple registers
- (0x18, b"\x18\x00\x01\x00\x01\x00\x00"), # read fifo queue
- (
- 0x2B,
- b"\x2b\x0e\x01\x01\x00\x00\x01\x00\x01\x77",
- ), # read device identification
- )
-
- self.exception = (
- (0x81, b"\x81\x01\xd0\x50"), # illegal function exception
- (0x82, b"\x82\x02\x90\xa1"), # illegal data address exception
- (0x83, b"\x83\x03\x50\xf1"), # illegal data value exception
- (0x84, b"\x84\x04\x13\x03"), # skave device failure exception
- (0x85, b"\x85\x05\xd3\x53"), # acknowledge exception
- (0x86, b"\x86\x06\x93\xa2"), # slave device busy exception
- (0x87, b"\x87\x08\x53\xf2"), # memory parity exception
- (0x88, b"\x88\x0a\x16\x06"), # gateway path unavailable exception
- (0x89, b"\x89\x0b\xd6\x56"), # gateway target failed exception
- )
-
- self.bad = (
- (0x80, b"\x80\x00\x00\x00"), # Unknown Function
- (0x81, b"\x81\x00\x00\x00"), # error message
- )
-
- def tearDown(self):
- """Clean up the test environment"""
- del self.bad
- del self.request
- del self.response
def test_exception_lookup(self):
"""Test that we can look up exception messages"""
for func, _ in self.exception:
response = self.client.lookupPduClass(func)
- self.assertNotEqual(response, None)
-
- for func, _ in self.exception:
- response = self.server.lookupPduClass(func)
- self.assertNotEqual(response, None)
+ assert response
def test_response_lookup(self):
"""Test a working response factory lookup"""
for func, _ in self.response:
response = self.client.lookupPduClass(func)
- self.assertNotEqual(response, None)
+ assert response
def test_request_lookup(self):
"""Test a working request factory lookup"""
for func, _ in self.request:
request = self.client.lookupPduClass(func)
- self.assertNotEqual(request, None)
+ assert request
def test_response_working(self):
"""Test a working response factory decoders"""
- for func, msg in self.response:
+ for _func, msg in self.response:
self.client.decode(msg)
def test_response_errors(self):
"""Test a response factory decoder exceptions"""
- self.assertRaises(
- ModbusException,
- self.client._helper, # pylint: disable=protected-access
- self.bad[0][1],
- )
- self.assertEqual(
- self.client.decode(self.bad[1][1]).function_code,
- self.bad[1][0],
- "Failed to decode error PDU",
- )
+ with pytest.raises(ModbusException):
+ self.client._helper(self.bad[0][1]) # pylint: disable=protected-access
+ assert (
+ self.client.decode(self.bad[1][1]).function_code == self.bad[1][0]
+ ), "Failed to decode error PDU"
def test_requests_working(self):
"""Test a working request factory decoders"""
- for func, msg in self.request:
+ for _func, msg in self.request:
self.server.decode(msg)
def test_client_factory_fails(self):
"""Tests that a client factory will fail to decode a bad message"""
self.client._helper = _raise_exception # pylint: disable=protected-access
actual = self.client.decode(None)
- self.assertEqual(actual, None)
+ assert not actual
def test_server_factory_fails(self):
"""Tests that a server factory will fail to decode a bad message"""
self.server._helper = _raise_exception # pylint: disable=protected-access
actual = self.server.decode(None)
- self.assertEqual(actual, None)
+ assert not actual
def test_server_register_custom_request(self):
"""Test server register custom request."""
@@ -174,16 +163,16 @@ class NoCustomRequest: # pylint: disable=too-few-public-methods
function_code = 0xFF
self.server.register(CustomRequest)
- self.assertTrue(self.client.lookupPduClass(CustomRequest.function_code))
+ assert self.client.lookupPduClass(CustomRequest.function_code)
CustomRequest.sub_function_code = 0xFF
self.server.register(CustomRequest)
- self.assertTrue(self.server.lookupPduClass(CustomRequest.function_code))
+ assert self.server.lookupPduClass(CustomRequest.function_code)
try:
func_raised = False
self.server.register(NoCustomRequest)
except MessageRegisterException:
func_raised = True
- self.assertTrue(func_raised)
+ assert func_raised
def test_client_register_custom_response(self):
"""Test client register custom response."""
@@ -199,16 +188,16 @@ class NoCustomResponse: # pylint: disable=too-few-public-methods
function_code = 0xFF
self.client.register(CustomResponse)
- self.assertTrue(self.client.lookupPduClass(CustomResponse.function_code))
+ assert self.client.lookupPduClass(CustomResponse.function_code)
CustomResponse.sub_function_code = 0xFF
self.client.register(CustomResponse)
- self.assertTrue(self.client.lookupPduClass(CustomResponse.function_code))
+ assert self.client.lookupPduClass(CustomResponse.function_code)
try:
func_raised = False
self.client.register(NoCustomResponse)
except MessageRegisterException:
func_raised = True
- self.assertTrue(func_raised)
+ assert func_raised
# ---------------------------------------------------------------------------#
# I don't actually know what is supposed to be returned here, I assume that
@@ -219,9 +208,7 @@ def test_request_errors(self):
"""Test a request factory decoder exceptions"""
for func, msg in self.bad:
result = self.server.decode(msg)
- self.assertEqual(result.ErrorCode, 1, "Failed to decode invalid requests")
- self.assertEqual(
- result.execute(None).function_code,
- func,
- "Failed to create correct response message",
- )
+ assert result.ErrorCode == 1, "Failed to decode invalid requests"
+ assert (
+ result.execute(None).function_code == func
+ ), "Failed to create correct response message"
diff --git a/test/test_file_message.py b/test/test_file_message.py
index 49558bacf..0214bb54b 100644
--- a/test/test_file_message.py
+++ b/test/test_file_message.py
@@ -6,7 +6,6 @@
* Read/Write Discretes
* Read Coils
"""
-import unittest
from test.conftest import MockContext
from pymodbus.file_message import (
@@ -28,49 +27,35 @@
# ---------------------------------------------------------------------------#
-class ModbusBitMessageTests(unittest.TestCase):
+class TestBitMessage:
"""Modbus bit message tests."""
- # -----------------------------------------------------------------------#
- # Setup/TearDown
- # -----------------------------------------------------------------------#
-
- def setUp(self):
- """Initialize the test environment and builds request/result encoding pairs."""
-
- def tearDown(self):
- """Clean up the test environment"""
-
- # -----------------------------------------------------------------------#
- # Read Fifo Queue
- # -----------------------------------------------------------------------#
-
def test_read_fifo_queue_request_encode(self):
"""Test basic bit message encoding/decoding"""
handle = ReadFifoQueueRequest(0x1234)
result = handle.encode()
- self.assertEqual(result, b"\x12\x34")
+ assert result == b"\x12\x34"
def test_read_fifo_queue_request_decode(self):
"""Test basic bit message encoding/decoding"""
handle = ReadFifoQueueRequest(0x0000)
handle.decode(b"\x12\x34")
- self.assertEqual(handle.address, 0x1234)
+ assert handle.address == 0x1234
def test_read_fifo_queue_request(self):
"""Test basic bit message encoding/decoding"""
context = MockContext()
handle = ReadFifoQueueRequest(0x1234)
result = handle.execute(context)
- self.assertTrue(isinstance(result, ReadFifoQueueResponse))
+ assert isinstance(result, ReadFifoQueueResponse)
handle.address = -1
result = handle.execute(context)
- self.assertEqual(ModbusExceptions.IllegalValue, result.exception_code)
+ assert ModbusExceptions.IllegalValue == result.exception_code
handle.values = [0x00] * 33
result = handle.execute(context)
- self.assertEqual(ModbusExceptions.IllegalValue, result.exception_code)
+ assert ModbusExceptions.IllegalValue == result.exception_code
def test_read_fifo_queue_request_error(self):
"""Test basic bit message encoding/decoding"""
@@ -78,27 +63,27 @@ def test_read_fifo_queue_request_error(self):
handle = ReadFifoQueueRequest(0x1234)
handle.values = [0x00] * 32
result = handle.execute(context)
- self.assertEqual(result.function_code, 0x98)
+ assert result.function_code == 0x98
def test_read_fifo_queue_response_encode(self):
"""Test that the read fifo queue response can encode"""
message = TEST_MESSAGE
handle = ReadFifoQueueResponse([1, 2, 3, 4])
result = handle.encode()
- self.assertEqual(result, message)
+ assert result == message
def test_read_fifo_queue_response_decode(self):
"""Test that the read fifo queue response can decode"""
message = TEST_MESSAGE
handle = ReadFifoQueueResponse([1, 2, 3, 4])
handle.decode(message)
- self.assertEqual(handle.values, [1, 2, 3, 4])
+ assert handle.values == [1, 2, 3, 4]
def test_rtu_frame_size(self):
"""Test that the read fifo queue response can decode"""
message = TEST_MESSAGE
result = ReadFifoQueueResponse.calculateRtuFrameSize(message)
- self.assertEqual(result, 14)
+ assert result == 14
# -----------------------------------------------------------------------#
# File Record
@@ -109,8 +94,8 @@ def test_file_record_length(self):
record = FileRecord(
file_number=0x01, record_number=0x02, record_data=b"\x00\x01\x02\x04"
)
- self.assertEqual(record.record_length, 0x02)
- self.assertEqual(record.response_length, 0x05)
+ assert record.record_length == 0x02
+ assert record.response_length == 0x05
def test_file_record_compare(self):
"""Test file record comparison operations"""
@@ -126,15 +111,15 @@ def test_file_record_compare(self):
record4 = FileRecord(
file_number=0x01, record_number=0x02, record_data=b"\x00\x01\x02\x04"
)
- self.assertTrue(record1 == record4)
- self.assertTrue(record1 != record2)
- self.assertNotEqual(record1, record2)
- self.assertNotEqual(record1, record3)
- self.assertNotEqual(record2, record3)
- self.assertEqual(record1, record4)
- self.assertEqual(str(record1), "FileRecord(file=1, record=2, length=2)")
- self.assertEqual(str(record2), "FileRecord(file=1, record=2, length=2)")
- self.assertEqual(str(record3), "FileRecord(file=2, record=3, length=2)")
+ assert record1 == record4
+ assert record1 != record2
+ assert record1 != record2
+ assert record1 != record3
+ assert record2 != record3
+ assert record1 == record4
+ assert str(record1) == "FileRecord(file=1, record=2, length=2)"
+ assert str(record2) == "FileRecord(file=1, record=2, length=2)"
+ assert str(record3) == "FileRecord(file=2, record=3, length=2)"
# -----------------------------------------------------------------------#
# Read File Record Request
@@ -145,7 +130,7 @@ def test_read_file_record_request_encode(self):
records = [FileRecord(file_number=0x01, record_number=0x02)]
handle = ReadFileRecordRequest(records)
result = handle.encode()
- self.assertEqual(result, b"\x07\x06\x00\x01\x00\x02\x00\x00")
+ assert result == b"\x07\x06\x00\x01\x00\x02\x00\x00"
def test_read_file_record_request_decode(self):
"""Test basic bit message encoding/decoding"""
@@ -153,7 +138,7 @@ def test_read_file_record_request_decode(self):
request = b"\x0e\x06\x00\x04\x00\x01\x00\x02\x06\x00\x03\x00\x09\x00\x02"
handle = ReadFileRecordRequest()
handle.decode(request)
- self.assertEqual(handle.records[0], record)
+ assert handle.records[0] == record
def test_read_file_record_request_rtu_frame_size(self):
"""Test basic bit message encoding/decoding"""
@@ -162,13 +147,13 @@ def test_read_file_record_request_rtu_frame_size(self):
)
handle = ReadFileRecordRequest()
size = handle.calculateRtuFrameSize(request)
- self.assertEqual(size, 0x0E + 5)
+ assert size == 0x0E + 5
def test_read_file_record_request_execute(self):
"""Test basic bit message encoding/decoding"""
handle = ReadFileRecordRequest()
result = handle.execute(None)
- self.assertTrue(isinstance(result, ReadFileRecordResponse))
+ assert isinstance(result, ReadFileRecordResponse)
# -----------------------------------------------------------------------#
# Read File Record Response
@@ -179,7 +164,7 @@ def test_read_file_record_response_encode(self):
records = [FileRecord(record_data=b"\x00\x01\x02\x03")]
handle = ReadFileRecordResponse(records)
result = handle.encode()
- self.assertEqual(result, b"\x06\x06\x02\x00\x01\x02\x03")
+ assert result == b"\x06\x06\x02\x00\x01\x02\x03"
def test_read_file_record_response_decode(self):
"""Test basic bit message encoding/decoding"""
@@ -189,14 +174,14 @@ def test_read_file_record_response_decode(self):
request = b"\x0c\x05\x06\x0d\xfe\x00\x20\x05\x05\x06\x33\xcd\x00\x40"
handle = ReadFileRecordResponse()
handle.decode(request)
- self.assertEqual(handle.records[0], record)
+ assert handle.records[0] == record
def test_read_file_record_response_rtu_frame_size(self):
"""Test basic bit message encoding/decoding"""
request = b"\x00\x00\x0c\x05\x06\x0d\xfe\x00\x20\x05\x05\x06\x33\xcd\x00\x40"
handle = ReadFileRecordResponse()
size = handle.calculateRtuFrameSize(request)
- self.assertEqual(size, 0x0C + 5)
+ assert size == 0x0C + 5
# -----------------------------------------------------------------------#
# Write File Record Request
@@ -211,7 +196,7 @@ def test_write_file_record_request_encode(self):
]
handle = WriteFileRecordRequest(records)
result = handle.encode()
- self.assertEqual(result, b"\x0b\x06\x00\x01\x00\x02\x00\x02\x00\x01\x02\x03")
+ assert result == b"\x0b\x06\x00\x01\x00\x02\x00\x02\x00\x01\x02\x03"
def test_write_file_record_request_decode(self):
"""Test basic bit message encoding/decoding"""
@@ -223,20 +208,20 @@ def test_write_file_record_request_decode(self):
request = b"\x0d\x06\x00\x04\x00\x07\x00\x03\x06\xaf\x04\xbe\x10\x0d"
handle = WriteFileRecordRequest()
handle.decode(request)
- self.assertEqual(handle.records[0], record)
+ assert handle.records[0] == record
def test_write_file_record_request_rtu_frame_size(self):
"""Test write file record request rtu frame size calculation"""
request = b"\x00\x00\x0d\x06\x00\x04\x00\x07\x00\x03\x06\xaf\x04\xbe\x10\x0d"
handle = WriteFileRecordRequest()
size = handle.calculateRtuFrameSize(request)
- self.assertEqual(size, 0x0D + 5)
+ assert size == 0x0D + 5
def test_write_file_record_request_execute(self):
"""Test basic bit message encoding/decoding"""
handle = WriteFileRecordRequest()
result = handle.execute(None)
- self.assertTrue(isinstance(result, WriteFileRecordResponse))
+ assert isinstance(result, WriteFileRecordResponse)
# -----------------------------------------------------------------------#
# Write File Record Response
@@ -251,7 +236,7 @@ def test_write_file_record_response_encode(self):
]
handle = WriteFileRecordResponse(records)
result = handle.encode()
- self.assertEqual(result, b"\x0b\x06\x00\x01\x00\x02\x00\x02\x00\x01\x02\x03")
+ assert result == b"\x0b\x06\x00\x01\x00\x02\x00\x02\x00\x01\x02\x03"
def test_write_file_record_response_decode(self):
"""Test basic bit message encoding/decoding"""
@@ -263,11 +248,11 @@ def test_write_file_record_response_decode(self):
request = b"\x0d\x06\x00\x04\x00\x07\x00\x03\x06\xaf\x04\xbe\x10\x0d"
handle = WriteFileRecordResponse()
handle.decode(request)
- self.assertEqual(handle.records[0], record)
+ assert handle.records[0] == record
def test_write_file_record_response_rtu_frame_size(self):
"""Test write file record response rtu frame size calculation"""
request = b"\x00\x00\x0d\x06\x00\x04\x00\x07\x00\x03\x06\xaf\x04\xbe\x10\x0d"
handle = WriteFileRecordResponse()
size = handle.calculateRtuFrameSize(request)
- self.assertEqual(size, 0x0D + 5)
+ assert size == 0x0D + 5
diff --git a/test/test_framers.py b/test/test_framers.py
index 2132088da..3f7b55702 100644
--- a/test/test_framers.py
+++ b/test/test_framers.py
@@ -1,5 +1,5 @@
"""Test framers."""
-from unittest.mock import Mock, patch
+from unittest import mock
import pytest
@@ -16,14 +16,14 @@
TEST_MESSAGE = b"\x00\x01\x00\x01\x00\n\xec\x1c"
-@pytest.fixture
-def rtu_framer():
+@pytest.fixture(name="rtu_framer")
+def fixture_rtu_framer():
"""RTU framer."""
return ModbusRtuFramer(ClientDecoder())
-@pytest.fixture
-def ascii_framer():
+@pytest.fixture(name="ascii_framer")
+def fixture_ascii_framer():
"""Ascii framer."""
return ModbusAsciiFramer(ClientDecoder())
@@ -82,8 +82,8 @@ def test_framer_initialization(framer):
]
-@pytest.mark.parametrize("data", [(b"", {}), (b"abcd", {"fcode": 98, "unit": 97})])
-def test_decode_data(rtu_framer, data): # pylint: disable=redefined-outer-name
+@pytest.mark.parametrize("data", [(b"", {}), (b"abcd", {"fcode": 98, "slave": 97})])
+def test_decode_data(rtu_framer, data):
"""Test decode data."""
data, expected = data
decoded = rtu_framer.decode_data(data)
@@ -99,7 +99,7 @@ def test_decode_data(rtu_framer, data): # pylint: disable=redefined-outer-name
(b"\x11\x03\x06\xAE\x41\x56\x52\x43\x40\x49\xAC", False), # invalid frame CRC
],
)
-def test_check_frame(rtu_framer, data): # pylint: disable=redefined-outer-name
+def test_check_frame(rtu_framer, data):
"""Test check frame."""
data, expected = data
rtu_framer._buffer = data # pylint: disable=protected-access
@@ -118,7 +118,7 @@ def test_check_frame(rtu_framer, data): # pylint: disable=redefined-outer-name
),
],
)
-def test_rtu_advance_framer(rtu_framer, data): # pylint: disable=redefined-outer-name
+def test_rtu_advance_framer(rtu_framer, data):
"""Test rtu advance framer."""
before_buf, before_header, after_buf = data
@@ -134,7 +134,7 @@ def test_rtu_advance_framer(rtu_framer, data): # pylint: disable=redefined-oute
@pytest.mark.parametrize("data", [b"", b"abcd"])
-def test_rtu_reset_framer(rtu_framer, data): # pylint: disable=redefined-outer-name
+def test_rtu_reset_framer(rtu_framer, data):
"""Test rtu reset framer."""
rtu_framer._buffer = data # pylint: disable=protected-access
rtu_framer.resetFrame()
@@ -143,7 +143,6 @@ def test_rtu_reset_framer(rtu_framer, data): # pylint: disable=redefined-outer-
"len": 0,
"crc": b"\x00\x00",
}
- assert rtu_framer._buffer == b"" # pylint: disable=protected-access
@pytest.mark.parametrize(
@@ -158,7 +157,7 @@ def test_rtu_reset_framer(rtu_framer, data): # pylint: disable=redefined-outer-
(b"\x11\x03\x06\xAE\x41\x56\x52\x43\x40\x49\xAD\xAB\xCD", True),
],
)
-def test_is_frame_ready(rtu_framer, data): # pylint: disable=redefined-outer-name
+def test_is_frame_ready(rtu_framer, data):
"""Test is frame ready."""
data, expected = data
rtu_framer._buffer = data # pylint: disable=protected-access
@@ -176,9 +175,7 @@ def test_is_frame_ready(rtu_framer, data): # pylint: disable=redefined-outer-na
b"\x11\x03\x06\xAE\x41\x56\x52\x43\x40\x43",
],
)
-def test_rtu_populate_header_fail(
- rtu_framer, data
-): # pylint: disable=redefined-outer-name
+def test_rtu_populate_header_fail(rtu_framer, data):
"""Test rtu populate header fail."""
with pytest.raises(IndexError):
rtu_framer.populateHeader(data)
@@ -197,33 +194,33 @@ def test_rtu_populate_header_fail(
),
],
)
-def test_rtu_populate_header(rtu_framer, data): # pylint: disable=redefined-outer-name
+def test_rtu_populate_header(rtu_framer, data):
"""Test rtu populate header."""
buffer, expected = data
rtu_framer.populateHeader(buffer)
assert rtu_framer._header == expected # pylint: disable=protected-access
-def test_add_to_frame(rtu_framer): # pylint: disable=redefined-outer-name
+def test_add_to_frame(rtu_framer):
"""Test add to frame."""
assert rtu_framer._buffer == b"" # pylint: disable=protected-access
rtu_framer.addToFrame(b"abcd")
assert rtu_framer._buffer == b"abcd" # pylint: disable=protected-access
-def test_get_frame(rtu_framer): # pylint: disable=redefined-outer-name
+def test_get_frame(rtu_framer):
"""Test get frame."""
rtu_framer.addToFrame(b"\x02\x01\x01\x00Q\xcc")
rtu_framer.populateHeader(b"\x02\x01\x01\x00Q\xcc")
assert rtu_framer.getFrame() == b"\x01\x01\x00"
-def test_populate_result(rtu_framer): # pylint: disable=redefined-outer-name
+def test_populate_result(rtu_framer):
"""Test populate result."""
rtu_framer._header["uid"] = 255 # pylint: disable=protected-access
- result = Mock()
+ result = mock.Mock()
rtu_framer.populateResult(result)
- assert result.unit_id == 255
+ assert result.slave_id == 255
@pytest.mark.parametrize(
@@ -255,36 +252,36 @@ def test_populate_result(rtu_framer): # pylint: disable=redefined-outer-name
(
b"\x11\x03\x06\xAE\x41\x56\x52\x43\x40\x49\xAD",
16,
- True,
False,
- ), # incorrect unit id
+ False,
+ ), # incorrect slave id
(b"\x11\x03\x06\xAE\x41\x56\x52\x43\x40\x49\xAD\x11\x03", 17, False, True),
# good frame + part of next frame
],
)
-def test_rtu_incoming_packet(rtu_framer, data): # pylint: disable=redefined-outer-name
+def test_rtu_incoming_packet(rtu_framer, data):
"""Test rtu process incoming packet."""
- buffer, units, reset_called, process_called = data
+ buffer, slaves, reset_called, process_called = data
- with patch.object(
+ with mock.patch.object(
rtu_framer,
"_process",
wraps=rtu_framer._process, # pylint: disable=protected-access
- ) as mock_process, patch.object(
+ ) as mock_process, mock.patch.object(
rtu_framer, "resetFrame", wraps=rtu_framer.resetFrame
) as mock_reset:
- rtu_framer.processIncomingPacket(buffer, Mock(), units)
+ rtu_framer.processIncomingPacket(buffer, mock.Mock(), slaves)
assert mock_process.call_count == (1 if process_called else 0)
assert mock_reset.call_count == (1 if reset_called else 0)
-def test_build_packet(rtu_framer): # pylint: disable=redefined-outer-name
+def test_build_packet(rtu_framer):
"""Test build packet."""
message = ReadCoilsRequest(1, 10)
assert rtu_framer.buildPacket(message) == TEST_MESSAGE
-def test_send_packet(rtu_framer): # pylint: disable=redefined-outer-name
+def test_send_packet(rtu_framer):
"""Test send packet."""
message = TEST_MESSAGE
client = ModbusBaseClient(framer=ModbusRtuFramer)
@@ -292,24 +289,24 @@ def test_send_packet(rtu_framer): # pylint: disable=redefined-outer-name
client.silent_interval = 1
client.last_frame_end = 1
client.params.timeout = 0.25
- client.idle_time = Mock(return_value=1)
- client.send = Mock(return_value=len(message))
+ client.idle_time = mock.Mock(return_value=1)
+ client.send = mock.Mock(return_value=len(message))
rtu_framer.client = client
assert rtu_framer.sendPacket(message) == len(message)
client.state = ModbusTransactionState.PROCESSING_REPLY
assert rtu_framer.sendPacket(message) == len(message)
-def test_recv_packet(rtu_framer): # pylint: disable=redefined-outer-name
+def test_recv_packet(rtu_framer):
"""Test receive packet."""
message = TEST_MESSAGE
- client = Mock()
+ client = mock.Mock()
client.recv.return_value = message
rtu_framer.client = client
assert rtu_framer.recvPacket(len(message)) == message
-def test_process(rtu_framer): # pylint: disable=redefined-outer-name
+def test_process(rtu_framer):
"""Test process."""
rtu_framer._buffer = TEST_MESSAGE # pylint: disable=protected-access
@@ -317,7 +314,7 @@ def test_process(rtu_framer): # pylint: disable=redefined-outer-name
rtu_framer._process(None) # pylint: disable=protected-access
-def test_get_raw_frame(rtu_framer): # pylint: disable=redefined-outer-name
+def test_get_raw_frame(rtu_framer):
"""Test get raw frame."""
rtu_framer._buffer = TEST_MESSAGE # pylint: disable=protected-access
assert (
@@ -326,20 +323,20 @@ def test_get_raw_frame(rtu_framer): # pylint: disable=redefined-outer-name
)
-def test_validate_unit_id(rtu_framer): # pylint: disable=redefined-outer-name
- """Test validate unit."""
+def test_validate__slave_id(rtu_framer):
+ """Test validate slave."""
rtu_framer.populateHeader(TEST_MESSAGE)
- assert rtu_framer._validate_unit_id([0], False) # pylint: disable=protected-access
- assert rtu_framer._validate_unit_id([1], True) # pylint: disable=protected-access
+ assert rtu_framer._validate_slave_id([0], False) # pylint: disable=protected-access
+ assert rtu_framer._validate_slave_id([1], True) # pylint: disable=protected-access
@pytest.mark.parametrize("data", [b":010100010001FC\r\n", b""])
-def test_decode_ascii_data(ascii_framer, data): # pylint: disable=redefined-outer-name
+def test_decode_ascii_data(ascii_framer, data):
"""Test decode ascii."""
data = ascii_framer.decode_data(data)
assert isinstance(data, dict)
if data:
- assert data.get("unit") == 1
+ assert data.get("slave") == 1
assert data.get("fcode") == 1
else:
assert not data
diff --git a/test/test_logging.py b/test/test_logging.py
index b1c5f911f..7c12f015a 100644
--- a/test/test_logging.py
+++ b/test/test_logging.py
@@ -1,6 +1,6 @@
"""Test datastore."""
import logging
-from unittest.mock import patch
+from unittest import mock
import pytest
@@ -12,7 +12,7 @@ class TestLogging:
def test_log_dont_call_build_msg(self):
"""Verify that build_msg is not called unnecessary"""
- with patch("pymodbus.logging.Log.build_msg") as build_msg_mock:
+ with mock.patch("pymodbus.logging.Log.build_msg") as build_msg_mock:
Log.setLevel(logging.INFO)
Log.debug("test")
build_msg_mock.assert_not_called()
@@ -28,7 +28,7 @@ def test_log_simple(self):
assert log_txt == txt
@pytest.mark.parametrize(
- "txt, result, params",
+ ("txt", "result", "params"),
[
("string {} {} {}", "string 101 102 103", (101, 102, 103)),
("string {}", "string 0x41 0x42 0x43 0x44", (b"ABCD", ":hex")),
diff --git a/test/test_mei_messages.py b/test/test_mei_messages.py
index f7131fac1..b1dbab820 100644
--- a/test/test_mei_messages.py
+++ b/test/test_mei_messages.py
@@ -3,7 +3,7 @@
This fixture tests the functionality of all the
mei based request/response messages:
"""
-import unittest
+import pytest
from pymodbus.constants import DeviceInformation
from pymodbus.device import ModbusControlBlock
@@ -21,7 +21,7 @@
TEST_MESSAGE = b"\x00\x07Company\x01\x07Product\x02\x07v2.1.12"
-class ModbusMeiMessageTest(unittest.TestCase):
+class TestMeiMessage:
"""Unittest for the pymodbus.mei_message module."""
# -----------------------------------------------------------------------#
@@ -33,15 +33,15 @@ def test_read_device_information_request_encode(self):
params = {"read_code": DeviceInformation.Basic, "object_id": 0x00}
handle = ReadDeviceInformationRequest(**params)
result = handle.encode()
- self.assertEqual(result, b"\x0e\x01\x00")
- self.assertEqual("ReadDeviceInformationRequest(1,0)", str(handle))
+ assert result == b"\x0e\x01\x00"
+ assert str(handle) == "ReadDeviceInformationRequest(1,0)"
def test_read_device_information_request_decode(self):
"""Test basic bit message encoding/decoding"""
handle = ReadDeviceInformationRequest()
handle.decode(b"\x0e\x01\x00")
- self.assertEqual(handle.read_code, DeviceInformation.Basic)
- self.assertEqual(handle.object_id, 0x00)
+ assert handle.read_code == DeviceInformation.Basic
+ assert not handle.object_id
def test_read_device_information_request(self):
"""Test basic bit message encoding/decoding"""
@@ -54,30 +54,30 @@ def test_read_device_information_request(self):
handle = ReadDeviceInformationRequest()
result = handle.execute(context)
- self.assertTrue(isinstance(result, ReadDeviceInformationResponse))
- self.assertEqual(result.information[0x00], "Company")
- self.assertEqual(result.information[0x01], "Product")
- self.assertEqual(result.information[0x02], TEST_VERSION)
- with self.assertRaises(KeyError):
+ assert isinstance(result, ReadDeviceInformationResponse)
+ assert result.information[0x00] == "Company"
+ assert result.information[0x01] == "Product"
+ assert result.information[0x02] == TEST_VERSION
+ with pytest.raises(KeyError):
_ = result.information[0x81]
handle = ReadDeviceInformationRequest(
read_code=DeviceInformation.Extended, object_id=0x80
)
result = handle.execute(context)
- self.assertEqual(result.information[0x81], ["Test", "Repeated"])
+ assert result.information[0x81] == ["Test", "Repeated"]
def test_read_device_information_request_error(self):
"""Test basic bit message encoding/decoding"""
handle = ReadDeviceInformationRequest()
handle.read_code = -1
- self.assertEqual(handle.execute(None).function_code, 0xAB)
+ assert handle.execute(None).function_code == 0xAB
handle.read_code = 0x05
- self.assertEqual(handle.execute(None).function_code, 0xAB)
+ assert handle.execute(None).function_code == 0xAB
handle.object_id = -1
- self.assertEqual(handle.execute(None).function_code, 0xAB)
+ assert handle.execute(None).function_code == 0xAB
handle.object_id = 0x100
- self.assertEqual(handle.execute(None).function_code, 0xAB)
+ assert handle.execute(None).function_code == 0xAB
def test_read_device_information_encode(self):
"""Test that the read fifo queue response can encode"""
@@ -92,8 +92,8 @@ def test_read_device_information_encode(self):
read_code=DeviceInformation.Basic, information=dataset
)
result = handle.encode()
- self.assertEqual(result, message)
- self.assertEqual("ReadDeviceInformationResponse(1)", str(handle))
+ assert result == message
+ assert str(handle) == "ReadDeviceInformationResponse(1)"
dataset = {
0x00: "Company",
@@ -108,7 +108,7 @@ def test_read_device_information_encode(self):
read_code=DeviceInformation.Extended, information=dataset
)
result = handle.encode()
- self.assertEqual(result, message)
+ assert result == message
def test_read_device_information_encode_long(self):
"""Test that the read fifo queue response can encode"""
@@ -132,8 +132,8 @@ def test_read_device_information_encode_long(self):
read_code=DeviceInformation.Basic, information=dataset
)
result = handle.encode()
- self.assertEqual(result, message)
- self.assertEqual("ReadDeviceInformationResponse(1)", str(handle))
+ assert result == message
+ assert str(handle) == "ReadDeviceInformationResponse(1)"
def test_read_device_information_decode(self):
"""Test that the read device information response can decode"""
@@ -142,12 +142,12 @@ def test_read_device_information_decode(self):
message += b"\x81\x04Test\x81\x08Repeated\x81\x07Another"
handle = ReadDeviceInformationResponse(read_code=0x00, information=[])
handle.decode(message)
- self.assertEqual(handle.read_code, DeviceInformation.Basic)
- self.assertEqual(handle.conformity, 0x01)
- self.assertEqual(handle.information[0x00], b"Company")
- self.assertEqual(handle.information[0x01], b"Product")
- self.assertEqual(handle.information[0x02], TEST_VERSION)
- self.assertEqual(handle.information[0x81], [b"Test", b"Repeated", b"Another"])
+ assert handle.read_code == DeviceInformation.Basic
+ assert handle.conformity == 0x01
+ assert handle.information[0x00] == b"Company"
+ assert handle.information[0x01] == b"Product"
+ assert handle.information[0x02] == TEST_VERSION
+ assert handle.information[0x81] == [b"Test", b"Repeated", b"Another"]
def test_rtu_frame_size(self):
"""Test that the read device information response can decode"""
@@ -155,7 +155,7 @@ def test_rtu_frame_size(self):
b"\x04\x2B\x0E\x01\x81\x00\x01\x01\x00\x06\x66\x6F\x6F\x62\x61\x72\xD7\x3B"
)
result = ReadDeviceInformationResponse.calculateRtuFrameSize(message)
- self.assertEqual(result, 18)
+ assert result == 18
message = b"\x00\x2B\x0E\x02\x00\x4D\x47"
result = ReadDeviceInformationRequest.calculateRtuFrameSize(message)
- self.assertEqual(result, 7)
+ assert result == 7
diff --git a/test/test_other_messages.py b/test/test_other_messages.py
index fcc53217f..ed2174e20 100644
--- a/test/test_other_messages.py
+++ b/test/test_other_messages.py
@@ -1,96 +1,88 @@
"""Test other messages."""
-import unittest
from unittest import mock
import pymodbus.other_message as pymodbus_message
-class ModbusOtherMessageTest(unittest.TestCase):
+class TestOtherMessage:
"""Unittest for the pymodbus.other_message module."""
- def setUp(self):
- """Do setup."""
- self.requests = [
- pymodbus_message.ReadExceptionStatusRequest,
- pymodbus_message.GetCommEventCounterRequest,
- pymodbus_message.GetCommEventLogRequest,
- pymodbus_message.ReportSlaveIdRequest,
- ]
-
- self.responses = [
- lambda: pymodbus_message.ReadExceptionStatusResponse(0x12),
- lambda: pymodbus_message.GetCommEventCounterResponse(0x12),
- pymodbus_message.GetCommEventLogResponse,
- lambda: pymodbus_message.ReportSlaveIdResponse(0x12),
- ]
-
- def tearDown(self):
- """Clean up the test environment."""
- del self.requests
- del self.responses
+ requests = [
+ pymodbus_message.ReadExceptionStatusRequest,
+ pymodbus_message.GetCommEventCounterRequest,
+ pymodbus_message.GetCommEventLogRequest,
+ pymodbus_message.ReportSlaveIdRequest,
+ ]
+
+ responses = [
+ lambda: pymodbus_message.ReadExceptionStatusResponse(0x12),
+ lambda: pymodbus_message.GetCommEventCounterResponse(0x12),
+ pymodbus_message.GetCommEventLogResponse,
+ lambda: pymodbus_message.ReportSlaveIdResponse(0x12),
+ ]
def test_other_messages_to_string(self):
"""Test other messages to string."""
for message in self.requests:
- self.assertNotEqual(str(message()), None)
+ assert str(message())
for message in self.responses:
- self.assertNotEqual(str(message()), None)
+ assert str(message())
def test_read_exception_status(self):
"""Test read exception status."""
request = pymodbus_message.ReadExceptionStatusRequest()
request.decode(b"\x12")
- self.assertEqual(request.encode(), b"")
- self.assertEqual(request.execute().function_code, 0x07)
+ assert not request.encode()
+ assert request.execute().function_code == 0x07
response = pymodbus_message.ReadExceptionStatusResponse(0x12)
- self.assertEqual(response.encode(), b"\x12")
+ assert response.encode() == b"\x12"
response.decode(b"\x12")
- self.assertEqual(response.status, 0x12)
+ assert response.status == 0x12
def test_get_comm_event_counter(self):
"""Test get comm event counter."""
request = pymodbus_message.GetCommEventCounterRequest()
request.decode(b"\x12")
- self.assertEqual(request.encode(), b"")
- self.assertEqual(request.execute().function_code, 0x0B)
+ assert not request.encode()
+ assert request.execute().function_code == 0x0B
response = pymodbus_message.GetCommEventCounterResponse(0x12)
- self.assertEqual(response.encode(), b"\x00\x00\x00\x12")
+ assert response.encode() == b"\x00\x00\x00\x12"
response.decode(b"\x00\x00\x00\x12")
- self.assertEqual(response.status, True)
- self.assertEqual(response.count, 0x12)
+ assert response.status
+ assert response.count == 0x12
response.status = False
- self.assertEqual(response.encode(), b"\xFF\xFF\x00\x12")
+ assert response.encode() == b"\xFF\xFF\x00\x12"
def test_get_comm_event_log(self):
"""Test get comm event log."""
request = pymodbus_message.GetCommEventLogRequest()
request.decode(b"\x12")
- self.assertEqual(request.encode(), b"")
- self.assertEqual(request.execute().function_code, 0x0C)
+ assert not request.encode()
+ assert request.execute().function_code == 0x0C
response = pymodbus_message.GetCommEventLogResponse()
- self.assertEqual(response.encode(), b"\x06\x00\x00\x00\x00\x00\x00")
+ assert response.encode() == b"\x06\x00\x00\x00\x00\x00\x00"
response.decode(b"\x06\x00\x00\x00\x12\x00\x12")
- self.assertEqual(response.status, True)
- self.assertEqual(response.message_count, 0x12)
- self.assertEqual(response.event_count, 0x12)
- self.assertEqual(response.events, [])
+ assert response.status
+ assert response.message_count == 0x12
+ assert response.event_count == 0x12
+ assert not response.events
response.status = False
- self.assertEqual(response.encode(), b"\x06\xff\xff\x00\x12\x00\x12")
+ assert response.encode() == b"\x06\xff\xff\x00\x12\x00\x12"
def test_get_comm_event_log_with_events(self):
"""Test get comm event log with events."""
response = pymodbus_message.GetCommEventLogResponse(events=[0x12, 0x34, 0x56])
- self.assertEqual(response.encode(), b"\x09\x00\x00\x00\x00\x00\x00\x12\x34\x56")
+ assert response.encode() == b"\x09\x00\x00\x00\x00\x00\x00\x12\x34\x56"
response.decode(b"\x09\x00\x00\x00\x12\x00\x12\x12\x34\x56")
- self.assertEqual(response.status, True)
- self.assertEqual(response.message_count, 0x12)
- self.assertEqual(response.event_count, 0x12)
- self.assertEqual(response.events, [0x12, 0x34, 0x56])
+ assert response.status
+ assert response.message_count == 0x12
+ assert response.event_count == 0x12
+ assert response.events == [0x12, 0x34, 0x56]
def test_report_slave_id_request(self):
"""Test report slave id request."""
@@ -112,7 +104,7 @@ def test_report_slave_id_request(self):
request = pymodbus_message.ReportSlaveIdRequest()
response = request.execute()
- self.assertEqual(response.identifier, expected_identity)
+ assert response.identifier == expected_identity
# Change to byte strings and test again (final result should be the same)
identity = {
@@ -130,7 +122,7 @@ def test_report_slave_id_request(self):
request = pymodbus_message.ReportSlaveIdRequest()
response = request.execute()
- self.assertEqual(response.identifier, expected_identity)
+ assert response.identifier == expected_identity
def test_report_slave_id(self):
"""Test report slave id."""
@@ -138,17 +130,17 @@ def test_report_slave_id(self):
dif.get.return_value = {}
request = pymodbus_message.ReportSlaveIdRequest()
request.decode(b"\x12")
- self.assertEqual(request.encode(), b"")
- self.assertEqual(request.execute().function_code, 0x11)
+ assert not request.encode()
+ assert request.execute().function_code == 0x11
response = pymodbus_message.ReportSlaveIdResponse(
request.execute().identifier, True
)
- self.assertEqual(response.encode(), b"\tPymodbus\xff")
+ assert response.encode() == b"\tPymodbus\xff"
response.decode(b"\x03\x12\x00")
- self.assertEqual(response.status, False)
- self.assertEqual(response.identifier, b"\x12\x00")
+ assert not response.status
+ assert response.identifier == b"\x12\x00"
response.status = False
- self.assertEqual(response.encode(), b"\x03\x12\x00\x00")
+ assert response.encode() == b"\x03\x12\x00\x00"
diff --git a/test/test_payload.py b/test/test_payload.py
index 93f64d929..272ad4a31 100644
--- a/test/test_payload.py
+++ b/test/test_payload.py
@@ -6,7 +6,7 @@
* PayloadBuilder
* PayloadDecoder
"""
-import unittest
+import pytest
from pymodbus.constants import Endian
from pymodbus.exceptions import ParameterException
@@ -18,35 +18,26 @@
# ---------------------------------------------------------------------------#
-class ModbusPayloadUtilityTests(unittest.TestCase):
+class TestPayloadUtility:
"""Modbus payload utility tests."""
- # ----------------------------------------------------------------------- #
- # Setup/TearDown
- # ----------------------------------------------------------------------- #
-
- def setUp(self):
- """Initialize the test environment and builds request/result encoding pairs."""
- self.little_endian_payload = (
- b"\x01\x02\x00\x03\x00\x00\x00\x04\x00\x00\x00\x00"
- b"\x00\x00\x00\xff\xfe\xff\xfd\xff\xff\xff\xfc\xff"
- b"\xff\xff\xff\xff\xff\xff\x00\x00\xa0\x3f\x00\x00"
- b"\x00\x00\x00\x00\x19\x40\x01\x00\x74\x65\x73\x74"
- b"\x11"
- )
-
- self.big_endian_payload = (
- b"\x01\x00\x02\x00\x00\x00\x03\x00\x00\x00\x00\x00"
- b"\x00\x00\x04\xff\xff\xfe\xff\xff\xff\xfd\xff\xff"
- b"\xff\xff\xff\xff\xff\xfc\x3f\xa0\x00\x00\x40\x19"
- b"\x00\x00\x00\x00\x00\x00\x00\x01\x74\x65\x73\x74"
- b"\x11"
- )
+ little_endian_payload = (
+ b"\x01\x02\x00\x03\x00\x00\x00\x04\x00\x00\x00\x00"
+ b"\x00\x00\x00\xff\xfe\xff\xfd\xff\xff\xff\xfc\xff"
+ b"\xff\xff\xff\xff\xff\xff\x00\x00\xa0\x3f\x00\x00"
+ b"\x00\x00\x00\x00\x19\x40\x01\x00\x74\x65\x73\x74"
+ b"\x11"
+ )
- self.bitstring = [True, False, False, False, True, False, False, False]
+ big_endian_payload = (
+ b"\x01\x00\x02\x00\x00\x00\x03\x00\x00\x00\x00\x00"
+ b"\x00\x00\x04\xff\xff\xfe\xff\xff\xff\xfd\xff\xff"
+ b"\xff\xff\xff\xff\xff\xfc\x3f\xa0\x00\x00\x40\x19"
+ b"\x00\x00\x00\x00\x00\x00\x00\x01\x74\x65\x73\x74"
+ b"\x11"
+ )
- def tearDown(self):
- """Clean up the test environment"""
+ bitstring = [True, False, False, False, True, False, False, False]
# ----------------------------------------------------------------------- #
# Payload Builder Tests
@@ -66,9 +57,9 @@ def test_little_endian_payload_builder(self):
builder.add_32bit_float(1.25)
builder.add_64bit_float(6.25)
builder.add_16bit_uint(1) # placeholder
- builder.add_string(b"test")
+ builder.add_string("test")
builder.add_bits(self.bitstring)
- self.assertEqual(self.little_endian_payload, builder.to_string())
+ assert self.little_endian_payload == builder.encode()
def test_big_endian_payload_builder(self):
"""Test basic bit message encoding/decoding"""
@@ -86,7 +77,7 @@ def test_big_endian_payload_builder(self):
builder.add_16bit_uint(1) # placeholder
builder.add_string("test")
builder.add_bits(self.bitstring)
- self.assertEqual(self.big_endian_payload, builder.to_string())
+ assert self.big_endian_payload == builder.encode()
def test_payload_builder_reset(self):
"""Test basic bit message encoding/decoding"""
@@ -95,11 +86,11 @@ def test_payload_builder_reset(self):
builder.add_8bit_uint(0x34)
builder.add_8bit_uint(0x56)
builder.add_8bit_uint(0x78)
- self.assertEqual(b"\x12\x34\x56\x78", builder.to_string())
- self.assertEqual([b"\x12\x34", b"\x56\x78"], builder.build())
+ assert builder.encode() == b"\x12\x34\x56\x78"
+ assert builder.build() == [b"\x12\x34", b"\x56\x78"]
builder.reset()
- self.assertEqual(b"", builder.to_string())
- self.assertEqual([], builder.build())
+ assert not builder.encode()
+ assert not builder.build()
def test_payload_builder_with_raw_payload(self):
"""Test basic bit message encoding/decoding"""
@@ -175,19 +166,19 @@ def test_payload_builder_with_raw_payload(self):
builder = BinaryPayloadBuilder(
[b"\x12", b"\x34", b"\x56", b"\x78"], repack=True
)
- self.assertEqual(b"\x12\x34\x56\x78", builder.to_string())
- self.assertEqual([13330, 30806], builder.to_registers())
+ assert builder.encode() == b"\x12\x34\x56\x78"
+ assert builder.to_registers() == [13330, 30806]
coils = builder.to_coils()
- self.assertEqual(_coils1, coils)
+ assert _coils1 == coils
builder = BinaryPayloadBuilder(
[b"\x12", b"\x34", b"\x56", b"\x78"], byteorder=Endian.Big
)
- self.assertEqual(b"\x12\x34\x56\x78", builder.to_string())
- self.assertEqual([4660, 22136], builder.to_registers())
- self.assertEqual("\x12\x34\x56\x78", str(builder))
+ assert builder.encode() == b"\x12\x34\x56\x78"
+ assert builder.to_registers() == [4660, 22136]
+ assert str(builder) == "\x12\x34\x56\x78"
coils = builder.to_coils()
- self.assertEqual(_coils2, coils)
+ assert _coils2 == coils
# ----------------------------------------------------------------------- #
# Payload Decoder Tests
@@ -198,71 +189,68 @@ def test_little_endian_payload_decoder(self):
decoder = BinaryPayloadDecoder(
self.little_endian_payload, byteorder=Endian.Little, wordorder=Endian.Little
)
- self.assertEqual(1, decoder.decode_8bit_uint())
- self.assertEqual(2, decoder.decode_16bit_uint())
- self.assertEqual(3, decoder.decode_32bit_uint())
- self.assertEqual(4, decoder.decode_64bit_uint())
- self.assertEqual(-1, decoder.decode_8bit_int())
- self.assertEqual(-2, decoder.decode_16bit_int())
- self.assertEqual(-3, decoder.decode_32bit_int())
- self.assertEqual(-4, decoder.decode_64bit_int())
- self.assertEqual(1.25, decoder.decode_32bit_float())
- self.assertEqual(6.25, decoder.decode_64bit_float())
- self.assertEqual(None, decoder.skip_bytes(2))
- self.assertEqual("test", decoder.decode_string(4).decode())
- self.assertEqual(self.bitstring, decoder.decode_bits())
+ assert decoder.decode_8bit_uint() == 1
+ assert decoder.decode_16bit_uint() == 2
+ assert decoder.decode_32bit_uint() == 3
+ assert decoder.decode_64bit_uint() == 4
+ assert decoder.decode_8bit_int() == -1
+ assert decoder.decode_16bit_int() == -2
+ assert decoder.decode_32bit_int() == -3
+ assert decoder.decode_64bit_int() == -4
+ assert decoder.decode_32bit_float() == 1.25
+ assert decoder.decode_64bit_float() == 6.25
+ assert not decoder.skip_bytes(2)
+ assert decoder.decode_string(4).decode() == "test"
+ assert self.bitstring == decoder.decode_bits()
def test_big_endian_payload_decoder(self):
"""Test basic bit message encoding/decoding"""
decoder = BinaryPayloadDecoder(self.big_endian_payload, byteorder=Endian.Big)
- self.assertEqual(1, decoder.decode_8bit_uint())
- self.assertEqual(2, decoder.decode_16bit_uint())
- self.assertEqual(3, decoder.decode_32bit_uint())
- self.assertEqual(4, decoder.decode_64bit_uint())
- self.assertEqual(-1, decoder.decode_8bit_int())
- self.assertEqual(-2, decoder.decode_16bit_int())
- self.assertEqual(-3, decoder.decode_32bit_int())
- self.assertEqual(-4, decoder.decode_64bit_int())
- self.assertEqual(1.25, decoder.decode_32bit_float())
- self.assertEqual(6.25, decoder.decode_64bit_float())
- self.assertEqual(None, decoder.skip_bytes(2))
- self.assertEqual(b"test", decoder.decode_string(4))
- self.assertEqual(self.bitstring, decoder.decode_bits())
+ assert decoder.decode_8bit_uint() == 1
+ assert decoder.decode_16bit_uint() == 2
+ assert decoder.decode_32bit_uint() == 3
+ assert decoder.decode_64bit_uint() == 4
+ assert decoder.decode_8bit_int() == -1
+ assert decoder.decode_16bit_int() == -2
+ assert decoder.decode_32bit_int() == -3
+ assert decoder.decode_64bit_int() == -4
+ assert decoder.decode_32bit_float() == 1.25
+ assert decoder.decode_64bit_float() == 6.25
+ assert not decoder.skip_bytes(2)
+ assert decoder.decode_string(4) == b"test"
+ assert self.bitstring == decoder.decode_bits()
def test_payload_decoder_reset(self):
"""Test the payload decoder reset functionality"""
decoder = BinaryPayloadDecoder(b"\x12\x34")
- self.assertEqual(0x12, decoder.decode_8bit_uint())
- self.assertEqual(0x34, decoder.decode_8bit_uint())
+ assert decoder.decode_8bit_uint() == 0x12
+ assert decoder.decode_8bit_uint() == 0x34
decoder.reset()
- self.assertEqual(0x3412, decoder.decode_16bit_uint())
+ assert decoder.decode_16bit_uint() == 0x3412
def test_payload_decoder_register_factory(self):
"""Test the payload decoder reset functionality"""
payload = [1, 2, 3, 4]
decoder = BinaryPayloadDecoder.fromRegisters(payload, byteorder=Endian.Little)
encoded = b"\x00\x01\x00\x02\x00\x03\x00\x04"
- self.assertEqual(encoded, decoder.decode_string(8))
+ assert encoded == decoder.decode_string(8)
decoder = BinaryPayloadDecoder.fromRegisters(payload, byteorder=Endian.Big)
encoded = b"\x00\x01\x00\x02\x00\x03\x00\x04"
- self.assertEqual(encoded, decoder.decode_string(8))
-
- self.assertRaises(
- ParameterException, lambda: BinaryPayloadDecoder.fromRegisters("abcd")
- )
+ assert encoded == decoder.decode_string(8)
+ with pytest.raises(ParameterException):
+ BinaryPayloadDecoder.fromRegisters("abcd")
def test_payload_decoder_coil_factory(self):
"""Test the payload decoder reset functionality"""
payload = [1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1]
decoder = BinaryPayloadDecoder.fromCoils(payload, byteorder=Endian.Little)
encoded = b"\x88\x11"
- self.assertEqual(encoded, decoder.decode_string(2))
+ assert encoded == decoder.decode_string(2)
decoder = BinaryPayloadDecoder.fromCoils(payload, byteorder=Endian.Big)
encoded = b"\x88\x11"
- self.assertEqual(encoded, decoder.decode_string(2))
+ assert encoded == decoder.decode_string(2)
- self.assertRaises(
- ParameterException, lambda: BinaryPayloadDecoder.fromCoils("abcd")
- )
+ with pytest.raises(ParameterException):
+ BinaryPayloadDecoder.fromCoils("abcd")
diff --git a/test/test_pdu.py b/test/test_pdu.py
index 8512c6eca..69021b09d 100644
--- a/test/test_pdu.py
+++ b/test/test_pdu.py
@@ -1,5 +1,5 @@
"""Test pdu."""
-import unittest
+import pytest
from pymodbus.exceptions import NotImplementedException
from pymodbus.pdu import (
@@ -11,31 +11,25 @@
)
-class SimplePduTest(unittest.TestCase):
+class TestPdu:
"""Unittest for the pymod.pdu module."""
- def setUp(self):
- """Initialize the test environment"""
- self.bad_requests = (
- ModbusRequest(),
- ModbusResponse(),
- )
- self.illegal = IllegalFunctionRequest(1)
- self.exception = ExceptionResponse(1, 1)
-
- def tearDown(self):
- """Clean up the test environment"""
- del self.bad_requests
- del self.illegal
- del self.exception
+ bad_requests = (
+ ModbusRequest(),
+ ModbusResponse(),
+ )
+ illegal = IllegalFunctionRequest(1)
+ exception = ExceptionResponse(1, 1)
def test_not_impelmented(self):
"""Test a base classes for not implemented functions"""
for request in self.bad_requests:
- self.assertRaises(NotImplementedException, request.encode)
+ with pytest.raises(NotImplementedException):
+ request.encode()
for request in self.bad_requests:
- self.assertRaises(NotImplementedException, request.decode, None)
+ with pytest.raises(NotImplementedException):
+ request.decode(None)
def test_error_methods(self):
"""Test all error methods"""
@@ -44,8 +38,8 @@ def test_error_methods(self):
result = self.exception.encode()
self.exception.decode(result)
- self.assertEqual(result, b"\x01")
- self.assertEqual(self.exception.exception_code, 1)
+ assert result == b"\x01"
+ assert self.exception.exception_code == 1
def test_request_exception_factory(self):
"""Test all error methods"""
@@ -54,37 +48,35 @@ def test_request_exception_factory(self):
errors = {ModbusExceptions.decode(c): c for c in range(1, 20)}
for error, code in iter(errors.items()):
result = request.doException(code)
- self.assertEqual(str(result), f"Exception Response(129, 1, {error})")
+ assert str(result) == f"Exception Response(129, 1, {error})"
def test_calculate_rtu_frame_size(self):
"""Test the calculation of Modbus/RTU frame sizes"""
- self.assertRaises(
- NotImplementedException, ModbusRequest.calculateRtuFrameSize, b""
- )
+ with pytest.raises(NotImplementedException):
+ ModbusRequest.calculateRtuFrameSize(b"")
ModbusRequest._rtu_frame_size = 5 # pylint: disable=protected-access
- self.assertEqual(ModbusRequest.calculateRtuFrameSize(b""), 5)
+ assert ModbusRequest.calculateRtuFrameSize(b"") == 5
del ModbusRequest._rtu_frame_size
ModbusRequest._rtu_byte_count_pos = 2 # pylint: disable=protected-access
- self.assertEqual(
+ assert (
ModbusRequest.calculateRtuFrameSize(
b"\x11\x01\x05\xcd\x6b\xb2\x0e\x1b\x45\xe6"
- ),
- 0x05 + 5,
+ )
+ == 0x05 + 5
)
del ModbusRequest._rtu_byte_count_pos
- self.assertRaises(
- NotImplementedException, ModbusResponse.calculateRtuFrameSize, b""
- )
+ with pytest.raises(NotImplementedException):
+ ModbusResponse.calculateRtuFrameSize(b"")
ModbusResponse._rtu_frame_size = 12 # pylint: disable=protected-access
- self.assertEqual(ModbusResponse.calculateRtuFrameSize(b""), 12)
+ assert ModbusResponse.calculateRtuFrameSize(b"") == 12
del ModbusResponse._rtu_frame_size
ModbusResponse._rtu_byte_count_pos = 2 # pylint: disable=protected-access
- self.assertEqual(
+ assert (
ModbusResponse.calculateRtuFrameSize(
b"\x11\x01\x05\xcd\x6b\xb2\x0e\x1b\x45\xe6"
- ),
- 0x05 + 5,
+ )
+ == 0x05 + 5
)
del ModbusResponse._rtu_byte_count_pos
diff --git a/test/test_register_read_messages.py b/test/test_register_read_messages.py
index 61c8df400..b9e19967a 100644
--- a/test/test_register_read_messages.py
+++ b/test/test_register_read_messages.py
@@ -1,5 +1,4 @@
"""Test register read messages."""
-import unittest
from test.conftest import FakeList, MockContext
from pymodbus.pdu import ModbusExceptions
@@ -22,7 +21,7 @@
# ---------------------------------------------------------------------------#
-class ReadRegisterMessagesTest(unittest.TestCase):
+class TestReadRegisterMessages:
"""Register Message Test Fixture.
This fixture tests the functionality of all the
@@ -32,7 +31,12 @@ class ReadRegisterMessagesTest(unittest.TestCase):
* Read Holding Registers
"""
- def setUp(self):
+ value = None
+ values = None
+ request_read = None
+ response_read = None
+
+ def setup_method(self):
"""Initialize the test environment and builds request/result encoding pairs."""
arguments = {
"read_address": 1,
@@ -64,26 +68,21 @@ def setUp(self):
ReadWriteMultipleRegistersResponse(self.values): TEST_MESSAGE,
}
- def tearDown(self):
- """Clean up the test environment."""
- del self.request_read
- del self.response_read
-
def test_read_register_response_base(self):
"""Test read register response."""
response = ReadRegistersResponseBase(list(range(10)))
for index in range(10):
- self.assertEqual(response.getRegister(index), index)
+ assert response.getRegister(index) == index
def test_register_read_requests(self):
"""Test register read requests."""
for request, response in iter(self.request_read.items()):
- self.assertEqual(request.encode(), response)
+ assert request.encode() == response
def test_register_read_responses(self):
"""Test register read response."""
for request, response in iter(self.response_read.items()):
- self.assertEqual(request.encode(), response)
+ assert request.encode() == response
def test_register_read_response_decode(self):
"""Test register read response."""
@@ -100,7 +99,7 @@ def test_register_read_response_decode(self):
for packet, register in zip(values, registers):
request, response = packet
request.decode(response)
- self.assertEqual(request.registers, register)
+ assert request.registers == register
def test_register_read_requests_count_errors(self):
"""This tests that the register request messages.
@@ -120,7 +119,7 @@ def test_register_read_requests_count_errors(self):
]
for request in requests:
result = request.execute(None)
- self.assertEqual(ModbusExceptions.IllegalValue, result.exception_code)
+ assert ModbusExceptions.IllegalValue == result.exception_code
def test_register_read_requests_validate_errors(self):
"""This tests that the register request messages.
@@ -136,7 +135,7 @@ def test_register_read_requests_validate_errors(self):
]
for request in requests:
result = request.execute(context)
- self.assertEqual(ModbusExceptions.IllegalAddress, result.exception_code)
+ assert ModbusExceptions.IllegalAddress == result.exception_code
def test_register_read_requests_execute(self):
"""This tests that the register request messages.
@@ -150,7 +149,7 @@ def test_register_read_requests_execute(self):
]
for request in requests:
response = request.execute(context)
- self.assertEqual(request.function_code, response.function_code)
+ assert request.function_code == response.function_code
def test_read_write_multiple_registers_request(self):
"""Test read/write multiple registers."""
@@ -159,7 +158,7 @@ def test_read_write_multiple_registers_request(self):
read_address=1, read_count=10, write_address=1, write_registers=[0x00]
)
response = request.execute(context)
- self.assertEqual(request.function_code, response.function_code)
+ assert request.function_code == response.function_code
def test_read_write_multiple_registers_validate(self):
"""Test read/write multiple registers."""
@@ -169,15 +168,15 @@ def test_read_write_multiple_registers_validate(self):
read_address=1, read_count=10, write_address=2, write_registers=[0x00]
)
response = request.execute(context)
- self.assertEqual(response.exception_code, ModbusExceptions.IllegalAddress)
+ assert response.exception_code == ModbusExceptions.IllegalAddress
context.validate = lambda f, a, c: a == 2
response = request.execute(context)
- self.assertEqual(response.exception_code, ModbusExceptions.IllegalAddress)
+ assert response.exception_code == ModbusExceptions.IllegalAddress
request.write_byte_count = 0x100
response = request.execute(context)
- self.assertEqual(response.exception_code, ModbusExceptions.IllegalValue)
+ assert response.exception_code == ModbusExceptions.IllegalValue
def test_read_write_multiple_registers_request_decode(self):
"""Test read/write multiple registers."""
@@ -187,16 +186,16 @@ def test_read_write_multiple_registers_request_decode(self):
if getattr(k, "function_code", 0) == 23
)
request.decode(response)
- self.assertEqual(request.read_address, 0x01)
- self.assertEqual(request.write_address, 0x01)
- self.assertEqual(request.read_count, 0x05)
- self.assertEqual(request.write_count, 0x05)
- self.assertEqual(request.write_byte_count, 0x0A)
- self.assertEqual(request.write_registers, [0x00] * 5)
+ assert request.read_address == 0x01
+ assert request.write_address == 0x01
+ assert request.read_count == 0x05
+ assert request.write_count == 0x05
+ assert request.write_byte_count == 0x0A
+ assert request.write_registers == [0x00] * 5
def test_serializing_to_string(self):
"""Test serializing to string."""
for request in iter(self.request_read.keys()):
- self.assertTrue(str(request) is not None)
+ assert str(request)
for request in iter(self.response_read.keys()):
- self.assertTrue(str(request) is not None)
+ assert str(request)
diff --git a/test/test_register_write_messages.py b/test/test_register_write_messages.py
index f19551046..b6e1c156f 100644
--- a/test/test_register_write_messages.py
+++ b/test/test_register_write_messages.py
@@ -1,5 +1,4 @@
"""Test register write messages."""
-import unittest
from test.conftest import MockContext, MockLastValuesContext
from pymodbus.payload import BinaryPayloadBuilder, Endian
@@ -19,7 +18,7 @@
# ---------------------------------------------------------------------------#
-class WriteRegisterMessagesTest(unittest.TestCase):
+class TestWriteRegisterMessages:
"""Register Message Test Fixture.
This fixture tests the functionality of all the
@@ -29,7 +28,13 @@ class WriteRegisterMessagesTest(unittest.TestCase):
* Read Holding Registers
"""
- def setUp(self):
+ value = None
+ values = None
+ builder = None
+ write = None
+ payload = None
+
+ def setup_method(self):
"""Initialize the test environment and builds request/result encoding pairs."""
self.value = 0xABCD
self.values = [0xA, 0xB, 0xC]
@@ -52,14 +57,10 @@ def setUp(self):
): b"\x00\x01\x00\x01\x02\x12\x34",
}
- def tearDown(self):
- """Clean up the test environment"""
- del self.write
-
def test_register_write_requests_encode(self):
"""Test register write requests encode."""
for request, response in iter(self.write.items()):
- self.assertEqual(request.encode(), response)
+ assert request.encode() == response
def test_register_write_requests_decode(self):
"""Test register write requests decode."""
@@ -71,56 +72,56 @@ def test_register_write_requests_decode(self):
for packet, address in zip(values, addresses):
request, response = packet
request.decode(response)
- self.assertEqual(request.address, address)
+ assert request.address == address
def test_invalid_write_multiple_registers_request(self):
"""Test invalid write multiple registers request."""
request = WriteMultipleRegistersRequest(0, None)
- self.assertEqual(request.values, [])
+ assert not request.values
def test_serializing_to_string(self):
"""Test serializing to string."""
for request in iter(self.write.keys()):
- self.assertTrue(str(request) is not None)
+ assert str(request)
def test_write_single_register_request(self):
"""Test write single register request."""
context = MockContext()
request = WriteSingleRegisterRequest(0x00, 0xF0000)
result = request.execute(context)
- self.assertEqual(result.exception_code, ModbusExceptions.IllegalValue)
+ assert result.exception_code == ModbusExceptions.IllegalValue
request.value = 0x00FF
result = request.execute(context)
- self.assertEqual(result.exception_code, ModbusExceptions.IllegalAddress)
+ assert result.exception_code == ModbusExceptions.IllegalAddress
context.valid = True
result = request.execute(context)
- self.assertEqual(result.function_code, request.function_code)
+ assert result.function_code == request.function_code
def test_write_multiple_register_request(self):
"""Test write multiple register request."""
context = MockContext()
request = WriteMultipleRegistersRequest(0x00, [0x00] * 10)
result = request.execute(context)
- self.assertEqual(result.exception_code, ModbusExceptions.IllegalAddress)
+ assert result.exception_code == ModbusExceptions.IllegalAddress
request.count = 0x05 # bytecode != code * 2
result = request.execute(context)
- self.assertEqual(result.exception_code, ModbusExceptions.IllegalValue)
+ assert result.exception_code == ModbusExceptions.IllegalValue
request.count = 0x800 # outside of range
result = request.execute(context)
- self.assertEqual(result.exception_code, ModbusExceptions.IllegalValue)
+ assert result.exception_code == ModbusExceptions.IllegalValue
context.valid = True
request = WriteMultipleRegistersRequest(0x00, [0x00] * 10)
result = request.execute(context)
- self.assertEqual(result.function_code, request.function_code)
+ assert result.function_code == request.function_code
request = WriteMultipleRegistersRequest(0x00, 0x00)
result = request.execute(context)
- self.assertEqual(result.function_code, request.function_code)
+ assert result.function_code == request.function_code
# -----------------------------------------------------------------------#
# Mask Write Register Request
@@ -130,16 +131,16 @@ def test_mask_write_register_request_encode(self):
"""Test basic bit message encoding/decoding"""
handle = MaskWriteRegisterRequest(0x0000, 0x0101, 0x1010)
result = handle.encode()
- self.assertEqual(result, b"\x00\x00\x01\x01\x10\x10")
+ assert result == b"\x00\x00\x01\x01\x10\x10"
def test_mask_write_register_request_decode(self):
"""Test basic bit message encoding/decoding"""
request = b"\x00\x04\x00\xf2\x00\x25"
handle = MaskWriteRegisterRequest()
handle.decode(request)
- self.assertEqual(handle.address, 0x0004)
- self.assertEqual(handle.and_mask, 0x00F2)
- self.assertEqual(handle.or_mask, 0x0025)
+ assert handle.address == 0x0004
+ assert handle.and_mask == 0x00F2
+ assert handle.or_mask == 0x0025
def test_mask_write_register_request_execute(self):
"""Test write register request valid execution"""
@@ -152,23 +153,23 @@ def test_mask_write_register_request_execute(self):
context = MockLastValuesContext(valid=True, default=0xAA55)
handle = MaskWriteRegisterRequest(0x0000, 0x0F0F, 0x00FF)
result = handle.execute(context)
- self.assertTrue(isinstance(result, MaskWriteRegisterResponse))
- self.assertEqual([0x0AF5], context.last_values)
+ assert isinstance(result, MaskWriteRegisterResponse)
+ assert context.last_values == [0x0AF5]
def test_mask_write_register_request_invalid_execute(self):
"""Test write register request execute with invalid data"""
context = MockContext(valid=False, default=0x0000)
handle = MaskWriteRegisterRequest(0x0000, -1, 0x1010)
result = handle.execute(context)
- self.assertEqual(ModbusExceptions.IllegalValue, result.exception_code)
+ assert ModbusExceptions.IllegalValue == result.exception_code
handle = MaskWriteRegisterRequest(0x0000, 0x0101, -1)
result = handle.execute(context)
- self.assertEqual(ModbusExceptions.IllegalValue, result.exception_code)
+ assert ModbusExceptions.IllegalValue == result.exception_code
handle = MaskWriteRegisterRequest(0x0000, 0x0101, 0x1010)
result = handle.execute(context)
- self.assertEqual(ModbusExceptions.IllegalAddress, result.exception_code)
+ assert ModbusExceptions.IllegalAddress == result.exception_code
# -----------------------------------------------------------------------#
# Mask Write Register Response
@@ -178,13 +179,13 @@ def test_mask_write_register_response_encode(self):
"""Test basic bit message encoding/decoding"""
handle = MaskWriteRegisterResponse(0x0000, 0x0101, 0x1010)
result = handle.encode()
- self.assertEqual(result, b"\x00\x00\x01\x01\x10\x10")
+ assert result == b"\x00\x00\x01\x01\x10\x10"
def test_mask_write_register_response_decode(self):
"""Test basic bit message encoding/decoding"""
request = b"\x00\x04\x00\xf2\x00\x25"
handle = MaskWriteRegisterResponse()
handle.decode(request)
- self.assertEqual(handle.address, 0x0004)
- self.assertEqual(handle.and_mask, 0x00F2)
- self.assertEqual(handle.or_mask, 0x0025)
+ assert handle.address == 0x0004
+ assert handle.and_mask == 0x00F2
+ assert handle.or_mask == 0x0025
diff --git a/test/test_remote_datastore.py b/test/test_remote_datastore.py
index b4ed047de..ec155d925 100644
--- a/test/test_remote_datastore.py
+++ b/test/test_remote_datastore.py
@@ -1,7 +1,8 @@
"""Test remote datastore."""
-import unittest
from unittest import mock
+import pytest
+
from pymodbus.bit_read_message import ReadCoilsResponse
from pymodbus.bit_write_message import WriteMultipleCoilsResponse
from pymodbus.datastore.remote import RemoteSlaveContext
@@ -10,17 +11,15 @@
from pymodbus.register_read_message import ReadInputRegistersResponse
-class RemoteModbusDataStoreTest(unittest.TestCase):
+class TestRemoteDataStore:
"""Unittest for the pymodbus.datastore.remote module."""
def test_remote_slave_context(self):
"""Test a modbus remote slave context"""
context = RemoteSlaveContext(None)
- self.assertNotEqual(str(context), None)
- self.assertRaises(
- NotImplementedException,
- lambda: context.reset(), # pylint: disable=unnecessary-lambda
- )
+ assert str(context)
+ with pytest.raises(NotImplementedException):
+ context.reset()
def test_remote_slave_set_values(self):
"""Test setting values against a remote slave context"""
@@ -40,15 +39,15 @@ def test_remote_slave_get_values(self):
context = RemoteSlaveContext(client)
context.validate(1, 0, 10)
result = context.getValues(1, 0, 10)
- self.assertEqual(result, [1] * 10)
+ assert result == [1] * 10
context.validate(4, 0, 10)
result = context.getValues(4, 0, 10)
- self.assertEqual(result, [10] * 10)
+ assert result == [10] * 10
context.validate(3, 0, 10)
result = context.getValues(3, 0, 10)
- self.assertNotEqual(result, [10] * 10)
+ assert result != [10] * 10
def test_remote_slave_validate_values(self):
"""Test validating against a remote slave context"""
@@ -59,10 +58,10 @@ def test_remote_slave_validate_values(self):
context = RemoteSlaveContext(client)
result = context.validate(1, 0, 10)
- self.assertTrue(result)
+ assert result
result = context.validate(4, 0, 10)
- self.assertTrue(result)
+ assert result
result = context.validate(3, 0, 10)
- self.assertFalse(result)
+ assert not result
diff --git a/test/test_repl_client.py b/test/test_repl_client.py
index 6a51acd48..5f39650f1 100755
--- a/test/test_repl_client.py
+++ b/test/test_repl_client.py
@@ -1,4 +1,6 @@
"""Test client sync."""
+from contextlib import suppress
+
from pymodbus.repl.client.main import _process_args
from pymodbus.server.reactive.default_config import DEFAULT_CONFIG
@@ -32,17 +34,11 @@ def test_repl_client_process_args():
resp = _process_args(["address=0b11", "value=0x10"], False)
assert resp == ({"address": 3, "value": 16}, True)
- try:
+ with suppress(ValueError):
resp = _process_args(["address=0xhj", "value=0x10"], False)
- except ValueError:
- pass
- try:
+ with suppress(ValueError):
resp = _process_args(["address=11ah", "value=0x10"], False)
- except ValueError:
- pass
- try:
+ with suppress(ValueError):
resp = _process_args(["address=0b12", "value=0x10"], False)
- except ValueError:
- pass
diff --git a/test/test_server_asyncio.py b/test/test_server_asyncio.py
index 9e279cbdd..2f2be1643 100755
--- a/test/test_server_asyncio.py
+++ b/test/test_server_asyncio.py
@@ -2,9 +2,9 @@
import asyncio
import logging
import ssl
-import unittest
from asyncio import CancelledError
-from unittest.mock import AsyncMock, Mock, patch
+from contextlib import suppress
+from unittest import mock
import pytest
@@ -88,20 +88,26 @@ def clear(cls):
BasicClient.my_protocol = None
-class AsyncioServerTest(
- unittest.IsolatedAsyncioTestCase
-): # pylint: disable=too-many-public-methods
+class TestAsyncioServer: # pylint: disable=too-many-public-methods
"""Unittest for the pymodbus.server.asyncio module.
- The scope of this unit test is the life-cycle management of the network
+ The scope of this test is the life-cycle management of the network
connections and server objects.
- This unittest suite does not attempt to test any of the underlying protocol details
+ This test suite does not attempt to test any of the underlying protocol details
"""
- def __init__(self, name):
- """Initialize."""
- super().__init__(name)
+ server = None
+ task = None
+ loop = None
+ store = None
+ context = None
+ identity = None
+
+ @pytest.fixture(autouse=True)
+ async def _setup_teardown(self):
+ """Initialize the test environment by setting up a dummy store and context."""
+ self.loop = asyncio.get_running_loop()
self.store = ModbusSlaveContext(
di=ModbusSequentialDataBlock(0, [17] * 100),
co=ModbusSequentialDataBlock(0, [17] * 100),
@@ -112,22 +118,9 @@ def __init__(self, name):
self.identity = ModbusDeviceIdentification(
info_name={"VendorName": "VendorName"}
)
- self.server = None
- self.task = None
- self.loop = None
-
- # -----------------------------------------------------------------------#
- # Setup/TearDown
- # -----------------------------------------------------------------------#
- def setUp(self):
- """Initialize the test environment by setting up a dummy store and context."""
-
- async def asyncSetUp(self):
- """Initialize the test environment by setting up a dummy store and context."""
- self.loop = asyncio.get_running_loop()
+ yield
- async def asyncTearDown(self):
- """Clean up the test environment"""
+ # teardown
if self.server is not None:
await self.server.server_close()
self.server = None
@@ -135,32 +128,23 @@ async def asyncTearDown(self):
await asyncio.sleep(0.1)
if not self.task.cancelled():
self.task.cancel()
- try:
+ with suppress(CancelledError):
await self.task
- except CancelledError:
- pass
self.task = None
- self.context = ModbusServerContext(slaves=self.store, single=True)
BasicClient.clear()
- def tearDown(self):
- """Clean up the test environment."""
-
def handle_task(self, result):
"""Handle task exit."""
- try:
+ with suppress(CancelledError):
result = result.result()
- except CancelledError:
- pass
async def start_server(
- self, do_forever=True, do_defer=True, do_tls=False, do_udp=False, do_ident=False
+ self, do_forever=True, do_tls=False, do_udp=False, do_ident=False
):
"""Handle setup and control of tcp server."""
args = {
"context": self.context,
"address": SERV_ADDR,
- "defer_start": do_defer,
}
if do_ident:
args["identity"] = self.identity
@@ -176,24 +160,24 @@ async def start_server(
self.server = ModbusTcpServer(
self.context, ModbusSocketFramer, self.identity, SERV_ADDR
)
- self.assertIsNotNone(self.server)
+ assert self.server
if do_forever:
self.task = asyncio.create_task(self.server.serve_forever())
self.task.add_done_callback(self.handle_task)
- self.assertFalse(self.task.cancelled())
+ assert not self.task.cancelled()
await asyncio.wait_for(self.server.serving, timeout=0.1)
if not do_udp:
- self.assertIsNotNone(self.server.server)
+ assert self.server.server
elif not do_udp: # pylint: disable=confusing-consecutive-elif
- self.assertIsNone(self.server.server)
- self.assertEqual(self.server.control.Identity.VendorName, "VendorName")
+ assert not self.server.server
+ assert self.server.control.Identity.VendorName == "VendorName"
await asyncio.sleep(0.1)
async def connect_server(self):
"""Handle connect to server"""
- BasicClient.connected = self.loop.create_future()
- BasicClient.done = self.loop.create_future()
- BasicClient.eof = self.loop.create_future()
+ BasicClient.connected = asyncio.Future()
+ BasicClient.done = asyncio.Future()
+ BasicClient.eof = asyncio.Future()
random_port = self.server.server.sockets[0].getsockname()[
1
] # get the random server port
@@ -221,39 +205,39 @@ async def test_async_start_server(self):
async def test_async_tcp_server_serve_forever_twice(self):
"""Call on serve_forever() twice should result in a runtime error"""
await self.start_server()
- with self.assertRaises(RuntimeError):
+ with pytest.raises(RuntimeError):
await self.server.serve_forever()
async def test_async_tcp_server_receive_data(self):
"""Test data sent on socket is received by internals - doesn't not process data"""
BasicClient.data = b"\x01\x00\x00\x00\x00\x06\x01\x03\x00\x00\x00\x19"
await self.start_server()
- with patch(
+ with mock.patch(
"pymodbus.transaction.ModbusSocketFramer.processIncomingPacket",
- new_callable=Mock,
+ new_callable=mock.Mock,
) as process:
await self.connect_server()
process.assert_called_once()
- self.assertTrue(process.call_args[1]["data"] == BasicClient.data)
+ assert process.call_args[1]["data"] == BasicClient.data
async def test_async_tcp_server_roundtrip(self):
"""Test sending and receiving data on tcp socket"""
expected_response = b"\x01\x00\x00\x00\x00\x05\x01\x03\x02\x00\x11"
- BasicClient.data = TEST_DATA # unit 1, read register
+ BasicClient.data = TEST_DATA # slave 1, read register
await self.start_server()
await self.connect_server()
await asyncio.wait_for(BasicClient.done, timeout=0.1)
- self.assertEqual(BasicClient.received_data, expected_response)
+ assert BasicClient.received_data, expected_response
async def test_async_tcp_server_connection_lost(self):
"""Test tcp stream interruption"""
await self.start_server()
await self.connect_server()
- self.assertEqual(len(self.server.active_connections), 1)
+ assert len(self.server.active_connections), 1
BasicClient.transport.close()
await asyncio.sleep(0.2) # so we have to wait a bit
- self.assertFalse(self.server.active_connections)
+ assert not self.server.active_connections
async def test_async_tcp_server_close_active_connection(self):
"""Test server_close() while there are active TCP connections"""
@@ -266,22 +250,22 @@ async def test_async_tcp_server_close_active_connection(self):
await self.server.server_close()
async def test_async_tcp_server_no_slave(self):
- """Test unknown slave unit exception"""
+ """Test unknown slave exception"""
self.context = ModbusServerContext(
slaves={0x01: self.store, 0x02: self.store}, single=False
)
BasicClient.data = b"\x01\x00\x00\x00\x00\x06\x05\x03\x00\x00\x00\x01"
await self.start_server()
await self.connect_server()
- self.assertFalse(BasicClient.eof.done())
- self.server.server_close()
+ assert not BasicClient.eof.done()
+ await self.server.server_close()
self.server = None
async def test_async_tcp_server_modbus_error(self):
"""Test sending garbage data on a TCP socket should drop the connection"""
BasicClient.data = TEST_DATA
await self.start_server()
- with patch(
+ with mock.patch(
"pymodbus.register_read_message.ReadHoldingRegistersRequest.execute",
side_effect=NoSuchSlaveException,
):
@@ -293,31 +277,30 @@ async def test_async_tcp_server_modbus_error(self):
# -----------------------------------------------------------------------#
async def test_async_start_tls_server_no_loop(self):
"""Test that the modbus tls asyncio server starts correctly"""
- with patch.object(ssl.SSLContext, "load_cert_chain"):
+ with mock.patch.object(ssl.SSLContext, "load_cert_chain"):
await self.start_server(do_tls=True, do_forever=False, do_ident=True)
- self.assertEqual(self.server.control.Identity.VendorName, "VendorName")
- self.assertIsNotNone(self.server.sslctx)
+ assert self.server.control.Identity.VendorName == "VendorName"
+ assert self.server.sslctx
async def test_async_start_tls_server(self):
"""Test that the modbus tls asyncio server starts correctly"""
- with patch.object(ssl.SSLContext, "load_cert_chain"):
+ with mock.patch.object(ssl.SSLContext, "load_cert_chain"):
await self.start_server(do_tls=True, do_ident=True)
- self.assertEqual(self.server.control.Identity.VendorName, "VendorName")
- self.assertIsNotNone(self.server.sslctx)
+ assert self.server.control.Identity.VendorName == "VendorName"
+ assert self.server.sslctx
async def test_async_tls_server_serve_forever(self):
"""Test StartAsyncTcpServer serve_forever() method"""
- with patch(
- "asyncio.base_events.Server.serve_forever", new_callable=AsyncMock
- ) as serve:
- with patch.object(ssl.SSLContext, "load_cert_chain"):
- await self.start_server(do_tls=True, do_forever=False)
- await self.server.serve_forever()
- serve.assert_awaited()
+ with mock.patch(
+ "asyncio.base_events.Server.serve_forever", new_callable=mock.AsyncMock
+ ) as serve, mock.patch.object(ssl.SSLContext, "load_cert_chain"):
+ await self.start_server(do_tls=True, do_forever=False)
+ await self.server.serve_forever()
+ serve.assert_awaited()
async def test_async_tls_server_serve_forever_twice(self):
"""Call on serve_forever() twice should result in a runtime error"""
- with patch.object(ssl.SSLContext, "load_cert_chain"):
+ with mock.patch.object(ssl.SSLContext, "load_cert_chain"):
await self.start_server(do_tls=True)
with pytest.raises(RuntimeError):
await self.server.serve_forever()
@@ -329,19 +312,19 @@ async def test_async_tls_server_serve_forever_twice(self):
async def test_async_start_udp_server_no_loop(self):
"""Test that the modbus udp asyncio server starts correctly"""
await self.start_server(do_udp=True, do_forever=False, do_ident=True)
- self.assertEqual(self.server.control.Identity.VendorName, "VendorName")
- self.assertIsNone(self.server.protocol)
+ assert self.server.control.Identity.VendorName == "VendorName"
+ assert not self.server.protocol
async def test_async_start_udp_server(self):
"""Test that the modbus udp asyncio server starts correctly"""
await self.start_server(do_udp=True, do_ident=True)
- self.assertEqual(self.server.control.Identity.VendorName, "VendorName")
- self.assertFalse(self.server.protocol is None)
+ assert self.server.control.Identity.VendorName == "VendorName"
+ assert self.server.protocol
async def test_async_udp_server_serve_forever_start(self):
"""Test StartAsyncUdpServer serve_forever() method"""
- with patch(
- "asyncio.base_events.Server.serve_forever", new_callable=AsyncMock
+ with mock.patch(
+ "asyncio.base_events.Server.serve_forever", new_callable=mock.AsyncMock
) as serve:
await self.start_server(do_forever=False, do_ident=True)
await self.server.serve_forever()
@@ -350,32 +333,30 @@ async def test_async_udp_server_serve_forever_start(self):
async def test_async_udp_server_serve_forever_close(self):
"""Test StarAsyncUdpServer serve_forever() method"""
await self.start_server(do_udp=True)
- self.assertTrue(asyncio.isfuture(self.server.on_connection_terminated))
- self.assertFalse(self.server.on_connection_terminated.done())
-
+ assert asyncio.isfuture(self.server.on_connection_terminated)
+ assert not self.server.on_connection_terminated.done()
await self.server.server_close()
- # TBD self.assertTrue(self.server.is_closing())
self.server = None
async def test_async_udp_server_serve_forever_twice(self):
"""Call on serve_forever() twice should result in a runtime error"""
await self.start_server(do_udp=True, do_ident=True)
- with self.assertRaises(RuntimeError):
+ with pytest.raises(RuntimeError):
await self.server.serve_forever()
@pytest.mark.skipif(pytest.IS_WINDOWS, reason="Windows have a timeout problem.")
async def test_async_udp_server_receive_data(self):
"""Test that the sending data on datagram socket gets data pushed to framer"""
await self.start_server(do_udp=True)
- with patch(
+ with mock.patch(
"pymodbus.transaction.ModbusSocketFramer.processIncomingPacket",
- new_callable=Mock,
+ new_callable=mock.Mock,
) as process:
self.server.endpoint.datagram_received(data=b"12345", addr=(SERV_IP, 12345))
await asyncio.sleep(0.1)
process.seal()
process.assert_called_once()
- self.assertTrue(process.call_args[1]["data"] == b"12345")
+ assert process.call_args[1]["data"] == b"12345"
async def test_async_udp_server_send_data(self):
"""Test that the modbus udp asyncio server correctly sends data outbound"""
@@ -384,7 +365,7 @@ async def test_async_udp_server_send_data(self):
random_port = self.server.protocol._sock.getsockname()[ # pylint: disable=protected-access
1
]
- received = self.server.endpoint.datagram_received = Mock(
+ received = self.server.endpoint.datagram_received = mock.Mock(
wraps=self.server.endpoint.datagram_received
)
await self.loop.create_datagram_endpoint(
@@ -392,15 +373,15 @@ async def test_async_udp_server_send_data(self):
)
await asyncio.sleep(0.1)
received.assert_called_once()
- self.assertEqual(received.call_args[0][0], BasicClient.dataTo)
+ assert received.call_args[0][0] == BasicClient.dataTo
await self.server.server_close()
self.server = None
async def test_async_udp_server_roundtrip(self):
"""Test sending and receiving data on udp socket"""
expected_response = b"\x01\x00\x00\x00\x00\x05\x01\x03\x02\x00\x11" # value of 17 as per context
- BasicClient.dataTo = TEST_DATA # unit 1, read register
- BasicClient.done = self.loop.create_future()
+ BasicClient.dataTo = TEST_DATA # slave 1, read register
+ BasicClient.done = asyncio.Future()
await self.start_server(do_udp=True)
random_port = self.server.protocol._sock.getsockname()[ # pylint: disable=protected-access
1
@@ -409,18 +390,18 @@ async def test_async_udp_server_roundtrip(self):
BasicClient, remote_addr=("127.0.0.1", random_port)
)
await asyncio.wait_for(BasicClient.done, timeout=0.1)
- self.assertEqual(BasicClient.received_data, expected_response)
+ assert BasicClient.received_data == expected_response
transport.close()
async def test_async_udp_server_exception(self):
"""Test sending garbage data on a TCP socket should drop the connection"""
BasicClient.dataTo = b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF"
- BasicClient.connected = self.loop.create_future()
- BasicClient.done = self.loop.create_future()
+ BasicClient.connected = asyncio.Future()
+ BasicClient.done = asyncio.Future()
await self.start_server(do_udp=True)
- with patch(
+ with mock.patch(
"pymodbus.transaction.ModbusSocketFramer.processIncomingPacket",
- new_callable=lambda: Mock(side_effect=Exception),
+ new_callable=lambda: mock.Mock(side_effect=Exception),
):
# get the random server port pylint: disable=protected-access
random_port = self.server.protocol._sock.getsockname()[1]
@@ -428,18 +409,16 @@ async def test_async_udp_server_exception(self):
BasicClient, remote_addr=("127.0.0.1", random_port)
)
await asyncio.wait_for(BasicClient.connected, timeout=0.1)
- self.assertFalse(BasicClient.done.done())
- self.assertFalse(
- self.server.protocol._sock._closed # pylint: disable=protected-access
- )
+ assert not BasicClient.done.done()
+ assert not self.server.protocol._sock._closed
async def test_async_tcp_server_exception(self):
"""Send garbage data on a TCP socket should drop the connection"""
BasicClient.data = b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF"
await self.start_server()
- with patch(
+ with mock.patch(
"pymodbus.transaction.ModbusSocketFramer.processIncomingPacket",
- new_callable=lambda: Mock(side_effect=Exception),
+ new_callable=lambda: mock.Mock(side_effect=Exception),
):
await self.connect_server()
await asyncio.wait_for(BasicClient.eof, timeout=0.1)
diff --git a/test/test_server_context.py b/test/test_server_context.py
index d4c48736c..9632ae33b 100644
--- a/test/test_server_context.py
+++ b/test/test_server_context.py
@@ -1,26 +1,24 @@
"""Test server context."""
-import unittest
+import pytest
from pymodbus.datastore import ModbusServerContext, ModbusSlaveContext
from pymodbus.exceptions import NoSuchSlaveException
-class ModbusServerSingleContextTest(unittest.TestCase):
- """This is the unittest for the pymodbus.datastore.ModbusServerContext using a single slave context."""
+class TestServerSingleContext:
+ """This is the test for the pymodbus.datastore.ModbusServerContext using a single slave context."""
- def setUp(self):
+ slave = ModbusSlaveContext()
+ context = None
+
+ def setup_method(self):
"""Set up the test environment"""
- self.slave = ModbusSlaveContext()
self.context = ModbusServerContext(slaves=self.slave, single=True)
- def tearDown(self):
- """Clean up the test environment"""
- del self.context
-
def test_single_context_gets(self):
"""Test getting on a single context"""
for slave_id in range(0, 0xFF):
- self.assertEqual(self.slave, self.context[slave_id])
+ assert self.slave == self.context[slave_id]
def test_single_context_deletes(self):
"""Test removing on multiple context"""
@@ -28,68 +26,70 @@ def test_single_context_deletes(self):
def _test():
del self.context[0x00]
- self.assertRaises(NoSuchSlaveException, _test)
+ with pytest.raises(NoSuchSlaveException):
+ _test()
def test_single_context_iter(self):
"""Test iterating over a single context"""
expected = (0, self.slave)
for slave in self.context:
- self.assertEqual(slave, expected)
+ assert slave == expected
def test_single_context_default(self):
"""Test that the single context default values work"""
self.context = ModbusServerContext()
slave = self.context[0x00]
- self.assertEqual(slave, {})
+ assert not slave
def test_single_context_set(self):
"""Test a setting a single slave context"""
slave = ModbusSlaveContext()
self.context[0x00] = slave
actual = self.context[0x00]
- self.assertEqual(slave, actual)
+ assert slave == actual
def test_single_context_register(self):
"""Test single context register."""
request_db = [1, 2, 3]
slave = ModbusSlaveContext()
slave.register(0xFF, "custom_request", request_db)
- self.assertEqual(slave.store["custom_request"], request_db)
- self.assertEqual(slave.decode(0xFF), "custom_request")
+ assert slave.store["custom_request"] == request_db
+ assert slave.decode(0xFF) == "custom_request"
-class ModbusServerMultipleContextTest(unittest.TestCase):
- """This is the unittest for the pymodbus.datastore.ModbusServerContext using multiple slave contexts."""
+class TestServerMultipleContext:
+ """This is the test for the pymodbus.datastore.ModbusServerContext using multiple slave contexts."""
- def setUp(self):
+ slaves = None
+ context = None
+
+ def setup_method(self):
"""Set up the test environment"""
self.slaves = {id: ModbusSlaveContext() for id in range(10)}
self.context = ModbusServerContext(slaves=self.slaves, single=False)
- def tearDown(self):
- """Clean up the test environment"""
- del self.context
-
def test_multiple_context_gets(self):
"""Test getting on multiple context"""
for slave_id in range(0, 10):
- self.assertEqual(self.slaves[slave_id], self.context[slave_id])
+ assert self.slaves[slave_id] == self.context[slave_id]
def test_multiple_context_deletes(self):
"""Test removing on multiple context"""
del self.context[0x00]
- self.assertRaises(NoSuchSlaveException, lambda: self.context[0x00])
+ with pytest.raises(NoSuchSlaveException):
+ self.context[0x00]()
def test_multiple_context_iter(self):
"""Test iterating over multiple context"""
for slave_id, slave in self.context:
- self.assertEqual(slave, self.slaves[slave_id])
- self.assertTrue(slave_id in self.context)
+ assert slave == self.slaves[slave_id]
+ assert slave_id in self.context
def test_multiple_context_default(self):
"""Test that the multiple context default values work"""
self.context = ModbusServerContext(single=False)
- self.assertRaises(NoSuchSlaveException, lambda: self.context[0x00])
+ with pytest.raises(NoSuchSlaveException):
+ self.context[0x00]()
def test_multiple_context_set(self):
"""Test a setting multiple slave contexts"""
@@ -98,4 +98,4 @@ def test_multiple_context_set(self):
self.context[slave_id] = slave
for slave_id, slave in iter(slaves.items()):
actual = self.context[slave_id]
- self.assertEqual(slave, actual)
+ assert slave == actual
diff --git a/test/test_server_multidrop.py b/test/test_server_multidrop.py
new file mode 100644
index 000000000..a13b76747
--- /dev/null
+++ b/test/test_server_multidrop.py
@@ -0,0 +1,169 @@
+"""Test server working as slave on a multidrop RS485 line."""
+from unittest import mock
+
+import pytest
+
+from pymodbus.framer.rtu_framer import ModbusRtuFramer
+from pymodbus.server.async_io import ServerDecoder
+
+
+class TestMultidrop:
+ """Test that server works on a multidrop line."""
+
+ slaves = [2]
+
+ good_frame = b"\x02\x03\x00\x01\x00}\xd4\x18"
+
+ @pytest.fixture(name="framer")
+ def fixture_framer(self):
+ """Prepare framer."""
+ return ModbusRtuFramer(ServerDecoder())
+
+ @pytest.fixture(name="callback")
+ def fixture_callback(self):
+ """Prepare dummy callback."""
+ return mock.Mock()
+
+ def test_ok_frame(self, framer, callback):
+ """Test ok frame."""
+ serial_event = self.good_frame
+ framer.processIncomingPacket(serial_event, callback, self.slaves)
+ callback.assert_called_once()
+
+ def test_ok_2frame(self, framer, callback):
+ """Test ok frame."""
+ serial_event = self.good_frame + self.good_frame
+ framer.processIncomingPacket(serial_event, callback, self.slaves)
+ assert callback.call_count == 2
+
+ def test_bad_crc(self, framer, callback):
+ """Test bad crc."""
+ serial_event = b"\x02\x03\x00\x01\x00}\xd4\x19" # Manually mangled crc
+ framer.processIncomingPacket(serial_event, callback, self.slaves)
+ callback.assert_not_called()
+
+ def test_wrong_id(self, framer, callback):
+ """Test frame wrong id"""
+ serial_event = b"\x01\x03\x00\x01\x00}\xd4+" # Frame with good CRC but other id
+ framer.processIncomingPacket(serial_event, callback, self.slaves)
+ callback.assert_not_called()
+
+ def test_big_split_response_frame_from_other_id(self, framer, callback):
+ """Test split response."""
+ # This is a single *response* from device id 1 after being queried for 125 holding register values
+ # Because the response is so long it spans several serial events
+ serial_events = [
+ b"\x01\x03\xfa\xc4y\xc0\x00\xc4y\xc0\x00\xc4y\xc0\x00\xc4y\xc0\x00\xc4y\xc0\x00Dz\x00\x00C\x96\x00\x00",
+ b"?\x05\x1e\xb8DH\x00\x00D\x96\x00\x00D\xfa\x00\x00DH\x00\x00D\x96\x00\x00D\xfa\x00\x00DH\x00",
+ b"\x00D\x96\x00\x00D\xfa\x00\x00B\x96\x00\x00B\xb4\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ b"\x00\x00\x00\x00\x00\x00\x00N,",
+ ]
+ for serial_event in serial_events:
+ framer.processIncomingPacket(serial_event, callback, self.slaves)
+ callback.assert_not_called()
+
+ def test_split_frame(self, framer, callback):
+ """Test split frame."""
+ serial_events = [self.good_frame[:5], self.good_frame[5:]]
+ for serial_event in serial_events:
+ framer.processIncomingPacket(serial_event, callback, self.slaves)
+ callback.assert_called_once()
+
+ def test_complete_frame_trailing_data_without_id(self, framer, callback):
+ """Test trailing data."""
+ garbage = b"\x05\x04\x03" # without id
+ serial_event = garbage + self.good_frame
+ framer.processIncomingPacket(serial_event, callback, self.slaves)
+ callback.assert_called_once()
+
+ def test_complete_frame_trailing_data_with_id(self, framer, callback):
+ """Test trailing data."""
+ garbage = b"\x05\x04\x03\x02\x01\x00" # with id
+ serial_event = garbage + self.good_frame
+ framer.processIncomingPacket(serial_event, callback, self.slaves)
+ callback.assert_called_once()
+
+ def test_split_frame_trailing_data_with_id(self, framer, callback):
+ """Test split frame."""
+ garbage = b"\x05\x04\x03\x02\x01\x00"
+ serial_events = [garbage + self.good_frame[:5], self.good_frame[5:]]
+ for serial_event in serial_events:
+ framer.processIncomingPacket(serial_event, callback, self.slaves)
+ callback.assert_called_once()
+
+ def test_coincidental_1(self, framer, callback):
+ """Test conincidental."""
+ garbage = b"\x02\x90\x07"
+ serial_events = [garbage, self.good_frame[:5], self.good_frame[5:]]
+ for serial_event in serial_events:
+ framer.processIncomingPacket(serial_event, callback, self.slaves)
+ callback.assert_called_once()
+
+ def test_coincidental_2(self, framer, callback):
+ """Test conincidental."""
+ garbage = b"\x02\x10\x07"
+ serial_events = [garbage, self.good_frame[:5], self.good_frame[5:]]
+ for serial_event in serial_events:
+ framer.processIncomingPacket(serial_event, callback, self.slaves)
+ callback.assert_called_once()
+
+ def test_coincidental_3(self, framer, callback):
+ """Test conincidental."""
+ garbage = b"\x02\x10\x07\x10"
+ serial_events = [garbage, self.good_frame[:5], self.good_frame[5:]]
+ for serial_event in serial_events:
+ framer.processIncomingPacket(serial_event, callback, self.slaves)
+ callback.assert_called_once()
+
+ def test_wrapped_frame(self, framer, callback):
+ """Test wrapped frame."""
+ garbage = b"\x05\x04\x03\x02\x01\x00"
+ serial_event = garbage + self.good_frame + garbage
+ framer.processIncomingPacket(serial_event, callback, self.slaves)
+
+ # We probably should not respond in this case; in this case we've likely become desynchronized
+ # i.e. this probably represents a case where a command came for us, but we didn't get
+ # to the serial buffer in time (some other co-routine or perhaps a block on the USB bus)
+ # and the master moved on and queried another device
+ callback.assert_called_once()
+
+ def test_frame_with_trailing_data(self, framer, callback):
+ """Test trailing data."""
+ garbage = b"\x05\x04\x03\x02\x01\x00"
+ serial_event = self.good_frame + garbage
+ framer.processIncomingPacket(serial_event, callback, self.slaves)
+
+ # We should not respond in this case for identical reasons as test_wrapped_frame
+ callback.assert_called_once()
+
+ def test_getFrameStart(self, framer):
+ """Test getFrameStart."""
+ framer_ok = b"\x02\x03\x00\x01\x00}\xd4\x18"
+ framer._buffer = framer_ok # pylint: disable=protected-access
+ assert framer.getFrameStart(self.slaves, False, False)
+ assert framer_ok == framer._buffer # pylint: disable=protected-access
+
+ framer_2ok = framer_ok + framer_ok
+ framer._buffer = framer_2ok # pylint: disable=protected-access
+ assert framer.getFrameStart(self.slaves, False, False)
+ assert framer_2ok == framer._buffer # pylint: disable=protected-access
+ assert framer.getFrameStart(self.slaves, False, True)
+ assert framer_ok == framer._buffer # pylint: disable=protected-access
+
+ framer._buffer = framer_ok[:2] # pylint: disable=protected-access
+ assert not framer.getFrameStart(self.slaves, False, False)
+ assert framer_ok[:2] == framer._buffer # pylint: disable=protected-access
+
+ framer._buffer = framer_ok[:3] # pylint: disable=protected-access
+ assert not framer.getFrameStart(self.slaves, False, False)
+ assert framer_ok[:3] == framer._buffer # pylint: disable=protected-access
+
+ framer_ok = b"\xF0\x03\x00\x01\x00}\xd4\x18"
+ framer._buffer = framer_ok # pylint: disable=protected-access
+ assert not framer.getFrameStart(self.slaves, False, False)
+ assert framer._buffer == framer_ok[-3:] # pylint: disable=protected-access
diff --git a/test/test_server_task.py b/test/test_server_task.py
index d1c7bc582..df382c65d 100755
--- a/test/test_server_task.py
+++ b/test/test_server_task.py
@@ -4,6 +4,7 @@
import os
from threading import Thread
from time import sleep
+from unittest import mock
import pytest
@@ -35,7 +36,7 @@ def helper_config(request, def_type):
datablock = ModbusSequentialDataBlock(0x00, [17] * 100)
context = ModbusServerContext(
slaves=ModbusSlaveContext(
- di=datablock, co=datablock, hr=datablock, ir=datablock, unit=1
+ di=datablock, co=datablock, hr=datablock, ir=datablock, slave=1
),
single=True,
)
@@ -139,23 +140,23 @@ def helper_config(request, def_type):
return cur_m["srv"], cur["srv_args"], cur_m["cli"], cur["cli_args"]
-@pytest.mark.xdist_group(name="task_serialize")
+@pytest.mark.xdist_group(name="server_serialize")
@pytest.mark.parametrize("comm", TEST_TYPES)
async def test_async_task_no_server(comm):
"""Test normal client/server handling."""
- run_server, server_args, run_client, client_args = helper_config(comm, "async")
+ _run_server, _server_args, run_client, client_args = helper_config(comm, "async")
client = run_client(**client_args)
try:
await client.connect()
- except Exception as exc: # pylint: disable=broad-except
- assert False, f"unexpected exception: {exc}"
+ except Exception as exc:
+ raise AssertionError(f"unexpected exception: {exc}") from exc
await asyncio.sleep(0.1)
with pytest.raises((asyncio.exceptions.TimeoutError, ConnectionException)):
await client.read_coils(1, 1, slave=0x01)
- await client.close()
+ client.close()
-@pytest.mark.xdist_group(name="task_serialize")
+@pytest.mark.xdist_group(name="server_serialize")
@pytest.mark.parametrize("comm", TEST_TYPES)
async def test_async_task_ok(comm):
"""Test normal client/server handling."""
@@ -166,29 +167,68 @@ async def test_async_task_ok(comm):
client = run_client(**client_args)
await client.connect()
await asyncio.sleep(0.1)
- assert client._connected # pylint: disable=protected-access
+ assert client.transport
rr = await client.read_coils(1, 1, slave=0x01)
assert len(rr.bits) == 8
- await client.close()
+ client.close()
await asyncio.sleep(0.1)
- assert not client._connected # pylint: disable=protected-access
+ assert not client.transport
await server.ServerAsyncStop()
+ task.cancel()
await task
-@pytest.mark.xdist_group(name="task_serialize")
+@pytest.mark.xdist_group(name="server_serialize")
@pytest.mark.parametrize("comm", TEST_TYPES)
-async def test_async_task_server_stop(comm):
+async def test_async_task_reuse(comm):
"""Test normal client/server handling."""
run_server, server_args, run_client, client_args = helper_config(comm, "async")
+
task = asyncio.create_task(run_server(**server_args))
await asyncio.sleep(0.1)
client = run_client(**client_args)
await client.connect()
- assert client._connected # pylint: disable=protected-access
+ await asyncio.sleep(0.1)
+ assert client.transport
+ rr = await client.read_coils(1, 1, slave=0x01)
+ assert len(rr.bits) == 8
+
+ client.close()
+ await asyncio.sleep(0.1)
+ assert not client.transport
+
+ await client.connect()
+ await asyncio.sleep(0.1)
+ assert client.transport
+ rr = await client.read_coils(1, 1, slave=0x01)
+ assert len(rr.bits) == 8
+
+ client.close()
+ await asyncio.sleep(0.1)
+ assert not client.transport
+
+ await server.ServerAsyncStop()
+ task.cancel()
+ await task
+
+
+@pytest.mark.xdist_group(name="server_serialize")
+@pytest.mark.parametrize("comm", TEST_TYPES)
+async def test_async_task_server_stop(comm):
+ """Test normal client/server handling."""
+ run_server, server_args, run_client, client_args = helper_config(comm, "async")
+ task = asyncio.create_task(run_server(**server_args))
+ await asyncio.sleep(0.5)
+
+ on_reconnect_callback = mock.Mock()
+
+ client = run_client(**client_args, on_reconnect_callback=on_reconnect_callback)
+ await client.connect()
+ assert client.transport
rr = await client.read_coils(1, 1, slave=0x01)
assert len(rr.bits) == 8
+ on_reconnect_callback.assert_not_called()
# Server breakdown
await server.ServerAsyncStop()
@@ -196,31 +236,30 @@ async def test_async_task_server_stop(comm):
with pytest.raises((ConnectionException, asyncio.exceptions.TimeoutError)):
rr = await client.read_coils(1, 1, slave=0x01)
- assert not client._connected # pylint: disable=protected-access
+ assert not client.transport
# Server back online
task = asyncio.create_task(run_server(**server_args))
- await asyncio.sleep(0.1)
+ await asyncio.sleep(1)
timer_allowed = 100
- while not client._connected: # pylint: disable=protected-access
+ while not client.transport and timer_allowed:
await asyncio.sleep(0.1)
timer_allowed -= 1
- if not timer_allowed:
- assert False, "client do not reconnect"
- assert client._connected # pylint: disable=protected-access
+ assert client.transport, "client do not reconnect"
+ # TBD on_reconnect_callback.assert_called()
rr = await client.read_coils(1, 1, slave=0x01)
assert len(rr.bits) == 8
- await client.close()
+ client.close()
await asyncio.sleep(0.5)
- assert not client._connected # pylint: disable=protected-access
+ assert not client.transport
await server.ServerAsyncStop()
await task
-@pytest.mark.xdist_group(name="task_serialize")
+@pytest.mark.xdist_group(name="server_serialize")
@pytest.mark.parametrize("comm", TEST_TYPES)
def test_sync_task_no_server(comm):
"""Test normal client/server handling."""
@@ -228,8 +267,8 @@ def test_sync_task_no_server(comm):
client = run_client(**client_args)
try:
client.connect()
- except Exception as exc: # pylint: disable=broad-except
- assert False, f"unexpected exception: {exc}"
+ except Exception as exc:
+ raise AssertionError(f"unexpected exception: {exc}") from exc
sleep(0.1)
if comm == "udp":
rr = client.read_coils(1, 1, slave=0x01)
@@ -240,7 +279,7 @@ def test_sync_task_no_server(comm):
client.close()
-@pytest.mark.xdist_group(name="task_serialize")
+@pytest.mark.xdist_group(name="server_serialize")
@pytest.mark.parametrize("comm", TEST_TYPES)
def test_sync_task_ok(comm):
"""Test normal client/server handling."""
@@ -265,7 +304,7 @@ def test_sync_task_ok(comm):
thread.join()
-@pytest.mark.xdist_group(name="task_serialize")
+@pytest.mark.xdist_group(name="server_serialize")
@pytest.mark.parametrize("comm", TEST_TYPES)
def test_sync_task_server_stop(comm):
"""Test normal client/server handling."""
@@ -304,7 +343,7 @@ def test_sync_task_server_stop(comm):
sleep(0.1)
timer_allowed -= 1
if not timer_allowed:
- assert False, "client do not reconnect"
+ pytest.fail("client do not reconnect")
assert client.socket
rr = client.read_coils(1, 1, slave=0x01)
diff --git a/test/test_simulator.py b/test/test_simulator.py
index 036be5a40..ead7bfc9e 100644
--- a/test/test_simulator.py
+++ b/test/test_simulator.py
@@ -1,18 +1,10 @@
"""Test datastore."""
-import asyncio
import copy
-import logging
import pytest
-from examples.client_async import setup_async_client
-from examples.helper import Commandline
-from examples.server_simulator import run_server_simulator, setup_simulator
-from pymodbus import pymodbus_apply_logging_config
from pymodbus.datastore import ModbusSimulatorContext
from pymodbus.datastore.simulator import Cell, CellType, Label
-from pymodbus.server import ServerAsyncStop
-from pymodbus.transaction import ModbusSocketFramer
FX_READ_BIT = 1
@@ -412,7 +404,7 @@ def test_simulator_get_text(self):
assert cell.count_write == str(reg.count_write), f"at register {test_reg}"
@pytest.mark.parametrize(
- "func,addr",
+ ("func", "addr"),
[
(FX_READ_BIT, 12),
(FX_READ_REG, 16),
@@ -462,27 +454,3 @@ def test_simulator_action_reset(self):
]
with pytest.raises(RuntimeError):
exc_simulator.getValues(FX_READ_REG, addr, 1)
-
- async def test_simulator_example(self):
- """Test datastore simulator example."""
- pymodbus_apply_logging_config(logging.DEBUG)
- # JAN activate.
- args = Commandline.copy()
- args.comm = "tcp"
- args.framer = ModbusSocketFramer
- args.port = 5051
- run_args = setup_simulator(
- args, setup=self.default_config, actions=self.custom_actions
- )
- if args:
- return # Turn off for now.
- asyncio.create_task(run_server_simulator(run_args))
- await asyncio.sleep(0.1)
- client = setup_async_client(args)
- await client.connect()
- assert client.connected
-
- rr = await client.read_holding_registers(16, 1, slave=1)
- assert rr.registers
- await client.close()
- await ServerAsyncStop()
diff --git a/test/test_sparse_datastore.py b/test/test_sparse_datastore.py
new file mode 100644
index 000000000..409fc3e01
--- /dev/null
+++ b/test/test_sparse_datastore.py
@@ -0,0 +1,25 @@
+"""Test framers."""
+
+
+from pymodbus.datastore import ModbusSparseDataBlock
+
+
+def test_check_sparsedatastore():
+ """Test check frame."""
+ data_in_block = {
+ 1: 6720,
+ 2: 130,
+ 30: [0x0D, 0xFE],
+ 105: [1, 2, 3, 4],
+ 20000: [45, 241, 48],
+ 20008: 38,
+ 48140: [0x4208, 0xCCCD],
+ }
+ datablock = ModbusSparseDataBlock(data_in_block)
+ for key, entry in data_in_block.items():
+ if isinstance(entry, int):
+ entry = [entry]
+ for value in entry:
+ assert datablock.validate(key, 1)
+ assert datablock.getValues(key, 1) == [value]
+ key += 1
diff --git a/test/test_transaction.py b/test/test_transaction.py
index 87cacb5b5..967f6e443 100755
--- a/test/test_transaction.py
+++ b/test/test_transaction.py
@@ -1,8 +1,9 @@
"""Test transaction."""
-import unittest
from binascii import a2b_hex
from itertools import count
-from unittest.mock import MagicMock, patch
+from unittest import mock
+
+import pytest
from pymodbus.exceptions import (
InvalidMessageReceivedException,
@@ -25,15 +26,24 @@
TEST_MESSAGE = b"\x7b\x01\x03\x00\x00\x00\x05\x85\xC9\x7d"
-class ModbusTransactionTest( # pylint: disable=too-many-public-methods
- unittest.TestCase
-):
+class TestTransaction: # pylint: disable=too-many-public-methods
"""Unittest for the pymodbus.transaction module."""
+ client = None
+ decoder = None
+ _tcp = None
+ _tls = None
+ _rtu = None
+ _ascii = None
+ _binary = None
+ _manager = None
+ _queue_manager = None
+ _tm = None
+
# ----------------------------------------------------------------------- #
# Test Construction
# ----------------------------------------------------------------------- #
- def setUp(self):
+ def setup_method(self):
"""Set up the test environment"""
self.client = None
self.decoder = ServerDecoder()
@@ -46,31 +56,24 @@ def setUp(self):
self._queue_manager = FifoTransactionManager(self.client)
self._tm = ModbusTransactionManager(self.client)
- def tearDown(self):
- """Clean up the test environment"""
- del self._manager
- del self._tcp
- del self._tls
- del self._rtu
- del self._ascii
-
# ----------------------------------------------------------------------- #
# Base transaction manager
# ----------------------------------------------------------------------- #
def test_calculate_expected_response_length(self):
"""Test calculate expected response length."""
- self._tm.client = MagicMock()
- self._tm.client.framer = MagicMock()
+ self._tm.client = mock.MagicMock()
+ self._tm.client.framer = mock.MagicMock()
self._tm._set_adu_size() # pylint: disable=protected-access
- self.assertEqual(
- self._tm._calculate_response_length(0), # pylint: disable=protected-access
- None,
+ assert (
+ not self._tm._calculate_response_length( # pylint: disable=protected-access
+ 0
+ )
)
self._tm.base_adu_size = 10
- self.assertEqual(
- self._tm._calculate_response_length(5), # pylint: disable=protected-access
- 15,
+ assert (
+ self._tm._calculate_response_length(5) # pylint: disable=protected-access
+ == 15
)
def test_calculate_exception_length(self):
@@ -83,7 +86,7 @@ def test_calculate_exception_length(self):
("tls", 2),
("dummy", None),
):
- self._tm.client = MagicMock()
+ self._tm.client = mock.MagicMock()
if framer == "ascii":
self._tm.client.framer = self._ascii
elif framer == "binary":
@@ -95,97 +98,99 @@ def test_calculate_exception_length(self):
elif framer == "tls":
self._tm.client.framer = self._tls
else:
- self._tm.client.framer = MagicMock()
+ self._tm.client.framer = mock.MagicMock()
self._tm._set_adu_size() # pylint: disable=protected-access
- self.assertEqual(
- self._tm._calculate_exception_length(), # pylint: disable=protected-access
- exception_length,
+ assert (
+ self._tm._calculate_exception_length() # pylint: disable=protected-access
+ == exception_length
)
- @patch("pymodbus.transaction.time")
+ @mock.patch("pymodbus.transaction.time")
def test_execute(self, mock_time):
"""Test execute."""
mock_time.time.side_effect = count()
- client = MagicMock()
+ client = mock.MagicMock()
client.framer = self._ascii
client.framer._buffer = b"deadbeef" # pylint: disable=protected-access
- client.framer.processIncomingPacket = MagicMock()
+ client.framer.processIncomingPacket = mock.MagicMock()
client.framer.processIncomingPacket.return_value = None
- client.framer.buildPacket = MagicMock()
+ client.framer.buildPacket = mock.MagicMock()
client.framer.buildPacket.return_value = b"deadbeef"
- client.framer.sendPacket = MagicMock()
+ client.framer.sendPacket = mock.MagicMock()
client.framer.sendPacket.return_value = len(b"deadbeef")
- client.framer.decode_data = MagicMock()
- client.framer.decode_data.return_value = {"unit": 1, "fcode": 222, "length": 27}
- request = MagicMock()
+ client.framer.decode_data = mock.MagicMock()
+ client.framer.decode_data.return_value = {
+ "slave": 1,
+ "fcode": 222,
+ "length": 27,
+ }
+ request = mock.MagicMock()
request.get_response_pdu_size.return_value = 10
- request.unit_id = 1
+ request.slave_id = 1
request.function_code = 222
trans = ModbusTransactionManager(client)
- trans._recv = MagicMock( # pylint: disable=protected-access
+ trans._recv = mock.MagicMock( # pylint: disable=protected-access
return_value=b"abcdef"
)
- self.assertEqual(trans.retries, 3)
- self.assertEqual(trans.retry_on_empty, False)
+ assert trans.retries == 3
+ assert not trans.retry_on_empty
- trans.getTransaction = MagicMock()
+ trans.getTransaction = mock.MagicMock()
trans.getTransaction.return_value = "response"
response = trans.execute(request)
- self.assertEqual(response, "response")
+ assert response == "response"
# No response
- trans._recv = MagicMock( # pylint: disable=protected-access
+ trans._recv = mock.MagicMock( # pylint: disable=protected-access
return_value=b"abcdef"
)
trans.transactions = []
- trans.getTransaction = MagicMock()
+ trans.getTransaction = mock.MagicMock()
trans.getTransaction.return_value = None
response = trans.execute(request)
- self.assertIsInstance(response, ModbusIOException)
+ assert isinstance(response, ModbusIOException)
# No response with retries
trans.retry_on_empty = True
- trans._recv = MagicMock( # pylint: disable=protected-access
+ trans._recv = mock.MagicMock( # pylint: disable=protected-access
side_effect=iter([b"", b"abcdef"])
)
response = trans.execute(request)
- self.assertIsInstance(response, ModbusIOException)
+ assert isinstance(response, ModbusIOException)
# wrong handle_local_echo
- trans._recv = MagicMock( # pylint: disable=protected-access
+ trans._recv = mock.MagicMock( # pylint: disable=protected-access
side_effect=iter([b"abcdef", b"deadbe", b"123456"])
)
client.handle_local_echo = True
trans.retry_on_empty = False
trans.retry_on_invalid = False
- self.assertEqual(
- trans.execute(request).message, "[Input/Output] Wrong local echo"
- )
+ assert trans.execute(request).message == "[Input/Output] Wrong local echo"
client.handle_local_echo = False
# retry on invalid response
trans.retry_on_invalid = True
- trans._recv = MagicMock( # pylint: disable=protected-access
+ trans._recv = mock.MagicMock( # pylint: disable=protected-access
side_effect=iter([b"", b"abcdef", b"deadbe", b"123456"])
)
response = trans.execute(request)
- self.assertIsInstance(response, ModbusIOException)
+ assert isinstance(response, ModbusIOException)
# Unable to decode response
- trans._recv = MagicMock( # pylint: disable=protected-access
+ trans._recv = mock.MagicMock( # pylint: disable=protected-access
side_effect=ModbusIOException()
)
- client.framer.processIncomingPacket.side_effect = MagicMock(
+ client.framer.processIncomingPacket.side_effect = mock.MagicMock(
side_effect=ModbusIOException()
)
- self.assertIsInstance(trans.execute(request), ModbusIOException)
+ assert isinstance(trans.execute(request), ModbusIOException)
# Broadcast
client.params.broadcast_enable = True
- request.unit_id = 0
+ request.slave_id = 0
response = trans.execute(request)
- self.assertEqual(response, b"Broadcast write sent - no response expected")
+ assert response == b"Broadcast write sent - no response expected"
# ----------------------------------------------------------------------- #
# Dictionary based transaction manager
@@ -194,9 +199,9 @@ def test_execute(self, mock_time):
def test_dict_transaction_manager_tid(self):
"""Test the dict transaction manager TID"""
for tid in range(1, self._manager.getNextTID() + 10):
- self.assertEqual(tid + 1, self._manager.getNextTID())
+ assert tid + 1 == self._manager.getNextTID()
self._manager.reset()
- self.assertEqual(1, self._manager.getNextTID())
+ assert self._manager.getNextTID() == 1
def test_get_dict_fifo_transaction_manager_transaction(self):
"""Test the dict transaction manager"""
@@ -212,7 +217,7 @@ class Request: # pylint: disable=too-few-public-methods
handle.message = b"testing" # pylint: disable=attribute-defined-outside-init
self._manager.addTransaction(handle)
result = self._manager.getTransaction(handle.transaction_id)
- self.assertEqual(handle.message, result.message)
+ assert handle.message == result.message
def test_delete_dict_fifo_transaction_manager_transaction(self):
"""Test the dict transaction manager"""
@@ -229,7 +234,7 @@ class Request: # pylint: disable=too-few-public-methods
self._manager.addTransaction(handle)
self._manager.delTransaction(handle.transaction_id)
- self.assertEqual(None, self._manager.getTransaction(handle.transaction_id))
+ assert not self._manager.getTransaction(handle.transaction_id)
# ----------------------------------------------------------------------- #
# Queue based transaction manager
@@ -237,9 +242,9 @@ class Request: # pylint: disable=too-few-public-methods
def test_fifo_transaction_manager_tid(self):
"""Test the fifo transaction manager TID"""
for tid in range(1, self._queue_manager.getNextTID() + 10):
- self.assertEqual(tid + 1, self._queue_manager.getNextTID())
+ assert tid + 1 == self._queue_manager.getNextTID()
self._queue_manager.reset()
- self.assertEqual(1, self._queue_manager.getNextTID())
+ assert self._queue_manager.getNextTID() == 1
def test_get_fifo_transaction_manager_transaction(self):
"""Test the fifo transaction manager"""
@@ -255,7 +260,7 @@ class Request: # pylint: disable=too-few-public-methods
handle.message = b"testing" # pylint: disable=attribute-defined-outside-init
self._queue_manager.addTransaction(handle)
result = self._queue_manager.getTransaction(handle.transaction_id)
- self.assertEqual(handle.message, result.message)
+ assert handle.message == result.message
def test_delete_fifo_transaction_manager_transaction(self):
"""Test the fifo transaction manager"""
@@ -272,9 +277,7 @@ class Request: # pylint: disable=too-few-public-methods
self._queue_manager.addTransaction(handle)
self._queue_manager.delTransaction(handle.transaction_id)
- self.assertEqual(
- None, self._queue_manager.getTransaction(handle.transaction_id)
- )
+ assert not self._queue_manager.getTransaction(handle.transaction_id)
# ----------------------------------------------------------------------- #
# TCP tests
@@ -282,23 +285,23 @@ class Request: # pylint: disable=too-few-public-methods
def test_tcp_framer_transaction_ready(self):
"""Test a tcp frame transaction"""
msg = b"\x00\x01\x12\x34\x00\x04\xff\x02\x12\x34"
- self.assertFalse(self._tcp.isFrameReady())
- self.assertFalse(self._tcp.checkFrame())
+ assert not self._tcp.isFrameReady()
+ assert not self._tcp.checkFrame()
self._tcp.addToFrame(msg)
- self.assertTrue(self._tcp.isFrameReady())
- self.assertTrue(self._tcp.checkFrame())
+ assert self._tcp.isFrameReady()
+ assert self._tcp.checkFrame()
self._tcp.advanceFrame()
- self.assertFalse(self._tcp.isFrameReady())
- self.assertFalse(self._tcp.checkFrame())
- self.assertEqual(b"", self._ascii.getFrame())
+ assert not self._tcp.isFrameReady()
+ assert not self._tcp.checkFrame()
+ assert self._ascii.getFrame() == b""
def test_tcp_framer_transaction_full(self):
"""Test a full tcp frame transaction"""
msg = b"\x00\x01\x12\x34\x00\x04\xff\x02\x12\x34"
self._tcp.addToFrame(msg)
- self.assertTrue(self._tcp.checkFrame())
+ assert self._tcp.checkFrame()
result = self._tcp.getFrame()
- self.assertEqual(msg[7:], result)
+ assert result == msg[7:]
self._tcp.advanceFrame()
def test_tcp_framer_transaction_half(self):
@@ -306,13 +309,13 @@ def test_tcp_framer_transaction_half(self):
msg1 = b"\x00\x01\x12\x34\x00"
msg2 = b"\x04\xff\x02\x12\x34"
self._tcp.addToFrame(msg1)
- self.assertFalse(self._tcp.checkFrame())
+ assert not self._tcp.checkFrame()
result = self._tcp.getFrame()
- self.assertEqual(b"", result)
+ assert result == b""
self._tcp.addToFrame(msg2)
- self.assertTrue(self._tcp.checkFrame())
+ assert self._tcp.checkFrame()
result = self._tcp.getFrame()
- self.assertEqual(msg2[2:], result)
+ assert result == msg2[2:]
self._tcp.advanceFrame()
def test_tcp_framer_transaction_half2(self):
@@ -320,13 +323,13 @@ def test_tcp_framer_transaction_half2(self):
msg1 = b"\x00\x01\x12\x34\x00\x04\xff"
msg2 = b"\x02\x12\x34"
self._tcp.addToFrame(msg1)
- self.assertFalse(self._tcp.checkFrame())
+ assert not self._tcp.checkFrame()
result = self._tcp.getFrame()
- self.assertEqual(b"", result)
+ assert result == b""
self._tcp.addToFrame(msg2)
- self.assertTrue(self._tcp.checkFrame())
+ assert self._tcp.checkFrame()
result = self._tcp.getFrame()
- self.assertEqual(msg2, result)
+ assert msg2 == result
self._tcp.advanceFrame()
def test_tcp_framer_transaction_half3(self):
@@ -334,13 +337,13 @@ def test_tcp_framer_transaction_half3(self):
msg1 = b"\x00\x01\x12\x34\x00\x04\xff\x02\x12"
msg2 = b"\x34"
self._tcp.addToFrame(msg1)
- self.assertFalse(self._tcp.checkFrame())
+ assert not self._tcp.checkFrame()
result = self._tcp.getFrame()
- self.assertEqual(msg1[7:], result)
+ assert result == msg1[7:]
self._tcp.addToFrame(msg2)
- self.assertTrue(self._tcp.checkFrame())
+ assert self._tcp.checkFrame()
result = self._tcp.getFrame()
- self.assertEqual(msg1[7:] + msg2, result)
+ assert result == msg1[7:] + msg2
self._tcp.advanceFrame()
def test_tcp_framer_transaction_short(self):
@@ -348,15 +351,15 @@ def test_tcp_framer_transaction_short(self):
msg1 = b"\x99\x99\x99\x99\x00\x01\x00\x01"
msg2 = b"\x00\x01\x12\x34\x00\x04\xff\x02\x12\x34"
self._tcp.addToFrame(msg1)
- self.assertFalse(self._tcp.checkFrame())
+ assert not self._tcp.checkFrame()
result = self._tcp.getFrame()
- self.assertEqual(b"", result)
+ assert result == b""
self._tcp.advanceFrame()
self._tcp.addToFrame(msg2)
- self.assertEqual(10, len(self._tcp._buffer)) # pylint: disable=protected-access
- self.assertTrue(self._tcp.checkFrame())
+ assert len(self._tcp._buffer) == 10 # pylint: disable=protected-access
+ assert self._tcp.checkFrame()
result = self._tcp.getFrame()
- self.assertEqual(msg2[7:], result)
+ assert result == msg2[7:]
self._tcp.advanceFrame()
def test_tcp_framer_populate(self):
@@ -364,14 +367,14 @@ def test_tcp_framer_populate(self):
expected = ModbusRequest()
expected.transaction_id = 0x0001
expected.protocol_id = 0x1234
- expected.unit_id = 0xFF
+ expected.slave_id = 0xFF
msg = b"\x00\x01\x12\x34\x00\x04\xff\x02\x12\x34"
self._tcp.addToFrame(msg)
- self.assertTrue(self._tcp.checkFrame())
+ assert self._tcp.checkFrame()
actual = ModbusRequest()
self._tcp.populateResult(actual)
- for name in ("transaction_id", "protocol_id", "unit_id"):
- self.assertEqual(getattr(expected, name), getattr(actual, name))
+ for name in ("transaction_id", "protocol_id", "slave_id"):
+ assert getattr(expected, name) == getattr(actual, name)
self._tcp.advanceFrame()
def test_tcp_framer_packet(self):
@@ -381,11 +384,11 @@ def test_tcp_framer_packet(self):
message = ModbusRequest()
message.transaction_id = 0x0001
message.protocol_id = 0x1234
- message.unit_id = 0xFF
+ message.slave_id = 0xFF
message.function_code = 0x01
expected = b"\x00\x01\x12\x34\x00\x02\xff\x01"
actual = self._tcp.buildPacket(message)
- self.assertEqual(expected, actual)
+ assert expected == actual
ModbusRequest.encode = old_encode
# ----------------------------------------------------------------------- #
@@ -394,23 +397,23 @@ def test_tcp_framer_packet(self):
def test_framer_tls_framer_transaction_ready(self):
"""Test a tls frame transaction"""
msg = b"\x01\x12\x34\x00\x08"
- self.assertFalse(self._tls.isFrameReady())
- self.assertFalse(self._tls.checkFrame())
+ assert not self._tls.isFrameReady()
+ assert not self._tls.checkFrame()
self._tls.addToFrame(msg)
- self.assertTrue(self._tls.isFrameReady())
- self.assertTrue(self._tls.checkFrame())
+ assert self._tls.isFrameReady()
+ assert self._tls.checkFrame()
self._tls.advanceFrame()
- self.assertFalse(self._tls.isFrameReady())
- self.assertFalse(self._tls.checkFrame())
- self.assertEqual(b"", self._tls.getFrame())
+ assert not self._tls.isFrameReady()
+ assert not self._tls.checkFrame()
+ assert self._tls.getFrame() == b""
def test_framer_tls_framer_transaction_full(self):
"""Test a full tls frame transaction"""
msg = b"\x01\x12\x34\x00\x08"
self._tls.addToFrame(msg)
- self.assertTrue(self._tls.checkFrame())
+ assert self._tls.checkFrame()
result = self._tls.getFrame()
- self.assertEqual(msg[0:], result)
+ assert result == msg[0:]
self._tls.advanceFrame()
def test_framer_tls_framer_transaction_half(self):
@@ -418,13 +421,13 @@ def test_framer_tls_framer_transaction_half(self):
msg1 = b""
msg2 = b"\x01\x12\x34\x00\x08"
self._tls.addToFrame(msg1)
- self.assertFalse(self._tls.checkFrame())
+ assert not self._tls.checkFrame()
result = self._tls.getFrame()
- self.assertEqual(b"", result)
+ assert result == b""
self._tls.addToFrame(msg2)
- self.assertTrue(self._tls.checkFrame())
+ assert self._tls.checkFrame()
result = self._tls.getFrame()
- self.assertEqual(msg2[0:], result)
+ assert result == msg2[0:]
self._tls.advanceFrame()
def test_framer_tls_framer_transaction_short(self):
@@ -432,15 +435,15 @@ def test_framer_tls_framer_transaction_short(self):
msg1 = b""
msg2 = b"\x01\x12\x34\x00\x08"
self._tls.addToFrame(msg1)
- self.assertFalse(self._tls.checkFrame())
+ assert not self._tls.checkFrame()
result = self._tls.getFrame()
- self.assertEqual(b"", result)
+ assert result == b""
self._tls.advanceFrame()
self._tls.addToFrame(msg2)
- self.assertEqual(5, len(self._tls._buffer)) # pylint: disable=protected-access
- self.assertTrue(self._tls.checkFrame())
+ assert len(self._tls._buffer) == 5 # pylint: disable=protected-access
+ assert self._tls.checkFrame()
result = self._tls.getFrame()
- self.assertEqual(msg2[0:], result)
+ assert result == msg2[0:]
self._tls.advanceFrame()
def test_framer_tls_framer_decode(self):
@@ -448,39 +451,36 @@ def test_framer_tls_framer_decode(self):
msg1 = b""
msg2 = b"\x01\x12\x34\x00\x08"
result = self._tls.decode_data(msg1)
- self.assertEqual({}, result)
+ assert not result
result = self._tls.decode_data(msg2)
- self.assertEqual({"fcode": 1}, result)
+ assert result == {"fcode": 1}
self._tls.advanceFrame()
def test_framer_tls_incoming_packet(self):
"""Framer tls incoming packet."""
msg = b"\x01\x12\x34\x00\x08"
- unit = 0x01
+ slave = 0x01
def mock_callback():
"""Mock callback."""
- self._tls._process = MagicMock() # pylint: disable=protected-access
- self._tls.isFrameReady = MagicMock(return_value=False)
- self._tls.processIncomingPacket(msg, mock_callback, unit)
- self.assertEqual(msg, self._tls.getRawFrame())
+ self._tls._process = mock.MagicMock() # pylint: disable=protected-access
+ self._tls.isFrameReady = mock.MagicMock(return_value=False)
+ self._tls.processIncomingPacket(msg, mock_callback, slave)
+ assert msg == self._tls.getRawFrame()
self._tls.advanceFrame()
- self._tls.isFrameReady = MagicMock(return_value=True)
- self._tls._validate_unit_id = MagicMock( # pylint: disable=protected-access
- return_value=False
- )
- self._tls.processIncomingPacket(msg, mock_callback, unit)
- self.assertEqual(b"", self._tls.getRawFrame())
+ self._tls.isFrameReady = mock.MagicMock(return_value=True)
+ x = mock.MagicMock(return_value=False)
+ self._tls._validate_slave_id = x # pylint: disable=protected-access
+ self._tls.processIncomingPacket(msg, mock_callback, slave)
+ assert not self._tls.getRawFrame()
self._tls.advanceFrame()
-
- self._tls._validate_unit_id = MagicMock( # pylint: disable=protected-access
- return_value=True
- )
- self._tls.processIncomingPacket(msg, mock_callback, unit)
- self.assertEqual(msg, self._tls.getRawFrame())
+ x = mock.MagicMock(return_value=True)
+ self._tls._validate_slave_id = x # pylint: disable=protected-access
+ self._tls.processIncomingPacket(msg, mock_callback, slave)
+ assert msg == self._tls.getRawFrame()
self._tls.advanceFrame()
def test_framer_tls_process(self):
@@ -496,37 +496,30 @@ def __init__(self, code):
def mock_callback(_arg):
"""Mock callback."""
- self._tls.decoder.decode = MagicMock(return_value=None)
- self.assertRaises(
- ModbusIOException,
- lambda: self._tls._process( # pylint: disable=protected-access
- mock_callback
- ),
- )
+ self._tls.decoder.decode = mock.MagicMock(return_value=None)
+ with pytest.raises(ModbusIOException):
+ self._tls._process(mock_callback) # pylint: disable=protected-access
result = MockResult(0x01)
- self._tls.decoder.decode = MagicMock(return_value=result)
- self.assertRaises(
- InvalidMessageReceivedException,
- lambda: self._tls._process( # pylint: disable=protected-access
+ self._tls.decoder.decode = mock.MagicMock(return_value=result)
+ with pytest.raises(InvalidMessageReceivedException):
+ self._tls._process( # pylint: disable=protected-access
mock_callback, error=True
- ),
- )
-
+ )
self._tls._process(mock_callback) # pylint: disable=protected-access
- self.assertEqual(b"", self._tls.getRawFrame())
+ assert not self._tls.getRawFrame()
def test_framer_tls_framer_populate(self):
"""Test a tls frame packet build"""
ModbusRequest()
msg = b"\x01\x12\x34\x00\x08"
self._tls.addToFrame(msg)
- self.assertTrue(self._tls.checkFrame())
+ assert self._tls.checkFrame()
actual = ModbusRequest()
result = self._tls.populateResult( # pylint: disable=assignment-from-none
actual
)
- self.assertEqual(None, result)
+ assert not result
self._tls.advanceFrame()
def test_framer_tls_framer_packet(self):
@@ -537,7 +530,7 @@ def test_framer_tls_framer_packet(self):
message.function_code = 0x01
expected = b"\x01"
actual = self._tls.buildPacket(message)
- self.assertEqual(expected, actual)
+ assert expected == actual
ModbusRequest.encode = old_encode
# ----------------------------------------------------------------------- #
@@ -545,25 +538,25 @@ def test_framer_tls_framer_packet(self):
# ----------------------------------------------------------------------- #
def test_rtu_framer_transaction_ready(self):
"""Test if the checks for a complete frame work"""
- self.assertFalse(self._rtu.isFrameReady())
+ assert not self._rtu.isFrameReady()
msg_parts = [b"\x00\x01\x00", b"\x00\x00\x01\xfc\x1b"]
self._rtu.addToFrame(msg_parts[0])
- self.assertFalse(self._rtu.isFrameReady())
- self.assertFalse(self._rtu.checkFrame())
+ assert not self._rtu.isFrameReady()
+ assert not self._rtu.checkFrame()
self._rtu.addToFrame(msg_parts[1])
- self.assertTrue(self._rtu.isFrameReady())
- self.assertTrue(self._rtu.checkFrame())
+ assert self._rtu.isFrameReady()
+ assert self._rtu.checkFrame()
def test_rtu_framer_transaction_full(self):
"""Test a full rtu frame transaction"""
msg = b"\x00\x01\x00\x00\x00\x01\xfc\x1b"
stripped_msg = msg[1:-2]
self._rtu.addToFrame(msg)
- self.assertTrue(self._rtu.checkFrame())
+ assert self._rtu.checkFrame()
result = self._rtu.getFrame()
- self.assertEqual(stripped_msg, result)
+ assert stripped_msg == result
self._rtu.advanceFrame()
def test_rtu_framer_transaction_half(self):
@@ -571,12 +564,12 @@ def test_rtu_framer_transaction_half(self):
msg_parts = [b"\x00\x01\x00", b"\x00\x00\x01\xfc\x1b"]
stripped_msg = b"".join(msg_parts)[1:-2]
self._rtu.addToFrame(msg_parts[0])
- self.assertFalse(self._rtu.checkFrame())
+ assert not self._rtu.checkFrame()
self._rtu.addToFrame(msg_parts[1])
- self.assertTrue(self._rtu.isFrameReady())
- self.assertTrue(self._rtu.checkFrame())
+ assert self._rtu.isFrameReady()
+ assert self._rtu.checkFrame()
result = self._rtu.getFrame()
- self.assertEqual(stripped_msg, result)
+ assert stripped_msg == result
self._rtu.advanceFrame()
def test_rtu_framer_populate(self):
@@ -588,22 +581,21 @@ def test_rtu_framer_populate(self):
self._rtu.populateResult(request)
header_dict = self._rtu._header # pylint: disable=protected-access
- self.assertEqual(len(msg), header_dict["len"])
- self.assertEqual(int(msg[0]), header_dict["uid"])
- self.assertEqual(msg[-2:], header_dict["crc"])
-
- self.assertEqual(0x00, request.unit_id)
+ assert len(msg) == header_dict["len"]
+ assert int(msg[0]) == header_dict["uid"]
+ assert msg[-2:] == header_dict["crc"]
+ assert not request.slave_id
def test_rtu_framer_packet(self):
"""Test a rtu frame packet build"""
old_encode = ModbusRequest.encode
ModbusRequest.encode = lambda self: b""
message = ModbusRequest()
- message.unit_id = 0xFF
+ message.slave_id = 0xFF
message.function_code = 0x01
expected = b"\xff\x01\x81\x80" # only header + CRC - no data
actual = self._rtu.buildPacket(message)
- self.assertEqual(expected, actual)
+ assert expected == actual
ModbusRequest.encode = old_encode
def test_rtu_decode_exception(self):
@@ -611,7 +603,7 @@ def test_rtu_decode_exception(self):
message = b"\x00\x90\x02\x9c\x01"
self._rtu.addToFrame(message)
result = self._rtu.checkFrame()
- self.assertTrue(result)
+ assert result
def test_process(self):
"""Test process."""
@@ -626,40 +618,36 @@ def mock_callback(_arg):
"""Mock callback."""
mock_result = MockResult(code=0)
- self._rtu.getRawFrame = self._rtu.getFrame = MagicMock()
- self._rtu.decoder = MagicMock()
- self._rtu.decoder.decode = MagicMock(return_value=mock_result)
- self._rtu.populateResult = MagicMock()
- self._rtu.advanceFrame = MagicMock()
+ self._rtu.getRawFrame = self._rtu.getFrame = mock.MagicMock()
+ self._rtu.decoder = mock.MagicMock()
+ self._rtu.decoder.decode = mock.MagicMock(return_value=mock_result)
+ self._rtu.populateResult = mock.MagicMock()
+ self._rtu.advanceFrame = mock.MagicMock()
self._rtu._process(mock_callback) # pylint: disable=protected-access
self._rtu.populateResult.assert_called_with(mock_result)
self._rtu.advanceFrame.assert_called_with()
- self.assertTrue(self._rtu.advanceFrame.called)
+ assert self._rtu.advanceFrame.called
# Check errors
- self._rtu.decoder.decode = MagicMock(return_value=None)
- self.assertRaises(
- ModbusIOException,
- lambda: self._rtu._process( # pylint: disable=protected-access
- mock_callback
- ),
- )
+ self._rtu.decoder.decode = mock.MagicMock(return_value=None)
+ with pytest.raises(ModbusIOException):
+ self._rtu._process(mock_callback) # pylint: disable=protected-access
def test_rtu_process_incoming_packets(self):
"""Test rtu process incoming packets."""
mock_data = b"\x00\x01\x00\x00\x00\x01\xfc\x1b"
- unit = 0x00
+ slave = 0x00
def mock_callback():
"""Mock callback."""
- self._rtu.addToFrame = MagicMock()
- self._rtu._process = MagicMock() # pylint: disable=protected-access
- self._rtu.isFrameReady = MagicMock(return_value=False)
+ self._rtu.addToFrame = mock.MagicMock()
+ self._rtu._process = mock.MagicMock() # pylint: disable=protected-access
+ self._rtu.isFrameReady = mock.MagicMock(return_value=False)
self._rtu._buffer = mock_data # pylint: disable=protected-access
- self._rtu.processIncomingPacket(mock_data, mock_callback, unit)
+ self._rtu.processIncomingPacket(mock_data, mock_callback, slave)
# ----------------------------------------------------------------------- #
# ASCII tests
@@ -667,24 +655,24 @@ def mock_callback():
def test_ascii_framer_transaction_ready(self):
"""Test a ascii frame transaction"""
msg = b":F7031389000A60\r\n"
- self.assertFalse(self._ascii.isFrameReady())
- self.assertFalse(self._ascii.checkFrame())
+ assert not self._ascii.isFrameReady()
+ assert not self._ascii.checkFrame()
self._ascii.addToFrame(msg)
- self.assertTrue(self._ascii.isFrameReady())
- self.assertTrue(self._ascii.checkFrame())
+ assert self._ascii.isFrameReady()
+ assert self._ascii.checkFrame()
self._ascii.advanceFrame()
- self.assertFalse(self._ascii.isFrameReady())
- self.assertFalse(self._ascii.checkFrame())
- self.assertEqual(b"", self._ascii.getFrame())
+ assert not self._ascii.isFrameReady()
+ assert not self._ascii.checkFrame()
+ assert not self._ascii.getFrame()
def test_ascii_framer_transaction_full(self):
"""Test a full ascii frame transaction"""
msg = b"sss:F7031389000A60\r\n"
pack = a2b_hex(msg[6:-4])
self._ascii.addToFrame(msg)
- self.assertTrue(self._ascii.checkFrame())
+ assert self._ascii.checkFrame()
result = self._ascii.getFrame()
- self.assertEqual(pack, result)
+ assert pack == result
self._ascii.advanceFrame()
def test_ascii_framer_transaction_half(self):
@@ -693,46 +681,46 @@ def test_ascii_framer_transaction_half(self):
msg2 = b"000A60\r\n"
pack = a2b_hex(msg1[6:] + msg2[:-4])
self._ascii.addToFrame(msg1)
- self.assertFalse(self._ascii.checkFrame())
+ assert not self._ascii.checkFrame()
result = self._ascii.getFrame()
- self.assertEqual(b"", result)
+ assert not result
self._ascii.addToFrame(msg2)
- self.assertTrue(self._ascii.checkFrame())
+ assert self._ascii.checkFrame()
result = self._ascii.getFrame()
- self.assertEqual(pack, result)
+ assert pack == result
self._ascii.advanceFrame()
def test_ascii_framer_populate(self):
"""Test a ascii frame packet build"""
request = ModbusRequest()
self._ascii.populateResult(request)
- self.assertEqual(0x00, request.unit_id)
+ assert not request.slave_id
def test_ascii_framer_packet(self):
"""Test a ascii frame packet build"""
old_encode = ModbusRequest.encode
ModbusRequest.encode = lambda self: b""
message = ModbusRequest()
- message.unit_id = 0xFF
+ message.slave_id = 0xFF
message.function_code = 0x01
expected = b":FF0100\r\n"
actual = self._ascii.buildPacket(message)
- self.assertEqual(expected, actual)
+ assert expected == actual
ModbusRequest.encode = old_encode
def test_ascii_process_incoming_packets(self):
"""Test ascii process incoming packet."""
mock_data = b":F7031389000A60\r\n"
- unit = 0x00
+ slave = 0x00
def mock_callback(_mock_data, *_args, **_kwargs):
"""Mock callback."""
- self._ascii.processIncomingPacket(mock_data, mock_callback, unit)
+ self._ascii.processIncomingPacket(mock_data, mock_callback, slave)
# Test failure:
- self._ascii.checkFrame = MagicMock(return_value=False)
- self._ascii.processIncomingPacket(mock_data, mock_callback, unit)
+ self._ascii.checkFrame = mock.MagicMock(return_value=False)
+ self._ascii.processIncomingPacket(mock_data, mock_callback, slave)
# ----------------------------------------------------------------------- #
# Binary tests
@@ -740,24 +728,24 @@ def mock_callback(_mock_data, *_args, **_kwargs):
def test_binary_framer_transaction_ready(self):
"""Test a binary frame transaction"""
msg = TEST_MESSAGE
- self.assertFalse(self._binary.isFrameReady())
- self.assertFalse(self._binary.checkFrame())
+ assert not self._binary.isFrameReady()
+ assert not self._binary.checkFrame()
self._binary.addToFrame(msg)
- self.assertTrue(self._binary.isFrameReady())
- self.assertTrue(self._binary.checkFrame())
+ assert self._binary.isFrameReady()
+ assert self._binary.checkFrame()
self._binary.advanceFrame()
- self.assertFalse(self._binary.isFrameReady())
- self.assertFalse(self._binary.checkFrame())
- self.assertEqual(b"", self._binary.getFrame())
+ assert not self._binary.isFrameReady()
+ assert not self._binary.checkFrame()
+ assert not self._binary.getFrame()
def test_binary_framer_transaction_full(self):
"""Test a full binary frame transaction"""
msg = TEST_MESSAGE
pack = msg[2:-3]
self._binary.addToFrame(msg)
- self.assertTrue(self._binary.checkFrame())
+ assert self._binary.checkFrame()
result = self._binary.getFrame()
- self.assertEqual(pack, result)
+ assert pack == result
self._binary.advanceFrame()
def test_binary_framer_transaction_half(self):
@@ -766,43 +754,43 @@ def test_binary_framer_transaction_half(self):
msg2 = b"\x00\x00\x05\x85\xC9\x7d"
pack = msg1[2:] + msg2[:-3]
self._binary.addToFrame(msg1)
- self.assertFalse(self._binary.checkFrame())
+ assert not self._binary.checkFrame()
result = self._binary.getFrame()
- self.assertEqual(b"", result)
+ assert not result
self._binary.addToFrame(msg2)
- self.assertTrue(self._binary.checkFrame())
+ assert self._binary.checkFrame()
result = self._binary.getFrame()
- self.assertEqual(pack, result)
+ assert pack == result
self._binary.advanceFrame()
def test_binary_framer_populate(self):
"""Test a binary frame packet build"""
request = ModbusRequest()
self._binary.populateResult(request)
- self.assertEqual(0x00, request.unit_id)
+ assert not request.slave_id
def test_binary_framer_packet(self):
"""Test a binary frame packet build"""
old_encode = ModbusRequest.encode
ModbusRequest.encode = lambda self: b""
message = ModbusRequest()
- message.unit_id = 0xFF
+ message.slave_id = 0xFF
message.function_code = 0x01
expected = b"\x7b\xff\x01\x81\x80\x7d"
actual = self._binary.buildPacket(message)
- self.assertEqual(expected, actual)
+ assert expected == actual
ModbusRequest.encode = old_encode
def test_binary_process_incoming_packet(self):
"""Test binary process incoming packet."""
mock_data = TEST_MESSAGE
- unit = 0x00
+ slave = 0x00
def mock_callback(_mock_data):
pass
- self._binary.processIncomingPacket(mock_data, mock_callback, unit)
+ self._binary.processIncomingPacket(mock_data, mock_callback, slave)
# Test failure:
- self._binary.checkFrame = MagicMock(return_value=False)
- self._binary.processIncomingPacket(mock_data, mock_callback, unit)
+ self._binary.checkFrame = mock.MagicMock(return_value=False)
+ self._binary.processIncomingPacket(mock_data, mock_callback, slave)
diff --git a/test/test_unix_socket.py b/test/test_unix_socket.py
index acb91f655..d41c1ab04 100755
--- a/test/test_unix_socket.py
+++ b/test/test_unix_socket.py
@@ -29,9 +29,9 @@ async def _helper_server(path_addon):
"""Run server."""
datablock = ModbusSequentialDataBlock(0x00, [17] * 100)
context = ModbusSlaveContext(
- di=datablock, co=datablock, hr=datablock, ir=datablock, unit=1
+ di=datablock, co=datablock, hr=datablock, ir=datablock, slave=1
)
- asyncio.create_task(
+ asyncio.create_task( # noqa: RUF006
StartAsyncUnixServer(
context=ModbusServerContext(slaves=context, single=True),
path=PATH + path_addon,
@@ -46,14 +46,14 @@ async def _helper_server(path_addon):
@pytest.mark.skipif(pytest.IS_WINDOWS, reason="Windows have a timeout problem.")
@pytest.mark.parametrize("path_addon", ["_1"])
async def test_unix_server(_mock_run_server):
- """Run async server with unit domain socket."""
+ """Run async server with unix domain socket."""
await asyncio.sleep(0.1)
@pytest.mark.skipif(pytest.IS_WINDOWS, reason="Windows have a timeout problem.")
@pytest.mark.parametrize("path_addon", ["_2"])
async def test_unix_async_client(path_addon, _mock_run_server):
- """Run async client with unit domain socket."""
+ """Run async client with unix domain socket."""
await asyncio.sleep(1)
client = AsyncModbusTcpClient(
HOST + path_addon,
diff --git a/test/test_utilities.py b/test/test_utilities.py
index bb5aefafa..d19743b3d 100644
--- a/test/test_utilities.py
+++ b/test/test_utilities.py
@@ -1,6 +1,5 @@
"""Test utilities."""
import struct
-import unittest
from pymodbus.utilities import (
checkCRC,
@@ -32,16 +31,29 @@ def __init__(self):
g_1 = dict_property(_test_master, 4)
-class SimpleUtilityTest(unittest.TestCase):
+class TestUtility:
"""Unittest for the pymod.utilities module."""
- def setUp(self):
+ def setup_method(self):
"""Initialize the test environment"""
- self.data = struct.pack(">HHHH", 0x1234, 0x2345, 0x3456, 0x4567)
- self.string = b"test the computation"
- self.bits = [True, False, True, False, True, False, True, False]
-
- def tearDown(self):
+ self.data = struct.pack( # pylint: disable=attribute-defined-outside-init
+ ">HHHH", 0x1234, 0x2345, 0x3456, 0x4567
+ )
+ self.string = ( # pylint: disable=attribute-defined-outside-init
+ b"test the computation"
+ )
+ self.bits = [ # pylint: disable=attribute-defined-outside-init
+ True,
+ False,
+ True,
+ False,
+ True,
+ False,
+ True,
+ False,
+ ]
+
+ def teardown_method(self):
"""Clean up the test environment"""
del self.bits
del self.string
@@ -49,44 +61,44 @@ def tearDown(self):
def test_dict_property(self):
"""Test all string <=> bit packing functions"""
result = DictPropertyTester()
- self.assertEqual(result.l_1, "a")
- self.assertEqual(result.l_2, "b")
- self.assertEqual(result.l_3, "c")
- self.assertEqual(result.s_1, "a")
- self.assertEqual(result.s_2, "b")
- self.assertEqual(result.g_1, "d")
+ assert result.l_1 == "a"
+ assert result.l_2 == "b"
+ assert result.l_3 == "c"
+ assert result.s_1 == "a"
+ assert result.s_2 == "b"
+ assert result.g_1 == "d"
for store in "l_1 l_2 l_3 s_1 s_2 g_1".split(" "):
setattr(result, store, "x")
- self.assertEqual(result.l_1, "x")
- self.assertEqual(result.l_2, "x")
- self.assertEqual(result.l_3, "x")
- self.assertEqual(result.s_1, "x")
- self.assertEqual(result.s_2, "x")
- self.assertEqual(result.g_1, "x")
+ assert result.l_1 == "x"
+ assert result.l_2 == "x"
+ assert result.l_3 == "x"
+ assert result.s_1 == "x"
+ assert result.s_2 == "x"
+ assert result.g_1 == "x"
def test_default_value(self):
"""Test all string <=> bit packing functions"""
- self.assertEqual(default(1), 0)
- self.assertEqual(default(1.1), 0.0)
- self.assertEqual(default(1 + 1j), 0j)
- self.assertEqual(default("string"), "")
- self.assertEqual(default([1, 2, 3]), [])
- self.assertEqual(default({1: 1}), {})
- self.assertEqual(default(True), False)
+ assert not default(1)
+ assert not default(1.1)
+ assert not default(1 + 1)
+ assert not default("string")
+ assert default([1, 2, 3]) == []
+ assert default({1: 1}) == {}
+ assert not default(True)
def test_bit_packing(self):
"""Test all string <=> bit packing functions"""
- self.assertEqual(unpack_bitstring(b"\x55"), self.bits)
- self.assertEqual(pack_bitstring(self.bits), b"\x55")
+ assert unpack_bitstring(b"\x55") == self.bits
+ assert pack_bitstring(self.bits) == b"\x55"
def test_longitudinal_redundancycheck(self):
"""Test the longitudinal redundancy check code"""
- self.assertTrue(checkLRC(self.data, 0x1C))
- self.assertTrue(checkLRC(self.string, 0x0C))
+ assert checkLRC(self.data, 0x1C)
+ assert checkLRC(self.string, 0x0C)
def test_cyclic_redundancy_check(self):
"""Test the cyclic redundancy check code"""
- self.assertTrue(checkCRC(self.data, 0xE2DB))
- self.assertTrue(checkCRC(self.string, 0x889E))
+ assert checkCRC(self.data, 0xE2DB)
+ assert checkCRC(self.string, 0x889E)
diff --git a/test/test_version.py b/test/test_version.py
deleted file mode 100644
index 17f329334..000000000
--- a/test/test_version.py
+++ /dev/null
@@ -1,27 +0,0 @@
-"""Test version."""
-import unittest
-
-from pymodbus import __version__ as pymodbus_version
-from pymodbus import __version_full__ as pymodbus_version_full
-from pymodbus.version import Version, version
-
-
-class ModbusVersionTest(unittest.TestCase):
- """Unittest for the pymodbus._version code."""
-
- def setUp(self):
- """Initialize the test environment"""
-
- def tearDown(self):
- """Clean up the test environment"""
-
- def test_version_class(self):
- """Test version class."""
- test_version = Version("test", 1, 2, 3, "sometag")
- self.assertEqual(test_version.short(), "1.2.3.sometag")
- self.assertEqual(str(test_version), "[test, version 1.2.3.sometag]")
- self.assertEqual(test_version.package, "test")
-
- self.assertEqual(pymodbus_version, version.short())
- self.assertEqual(pymodbus_version_full, str(version))
- self.assertEqual(version.package, "pymodbus")
diff --git a/test/transport/test_basic.py b/test/transport/test_basic.py
new file mode 100644
index 000000000..930c7aa94
--- /dev/null
+++ b/test/transport/test_basic.py
@@ -0,0 +1,504 @@
+"""Test transport."""
+import asyncio
+import os
+from unittest import mock
+
+import pytest
+from serial import SerialException
+
+from pymodbus.framer import ModbusFramer
+from pymodbus.transport.transport import BaseTransport
+
+
+class TestBaseTransport:
+ """Test transport module, base part."""
+
+ base_comm_name = "test comm"
+ base_reconnect_delay = 1
+ base_reconnect_delay_max = 3.5
+ base_timeout_connect = 2
+ base_framer = ModbusFramer
+ base_host = "test host"
+ base_port = 502
+ base_server_hostname = "server test host"
+ base_baudrate = 9600
+ base_bytesize = 8
+ base_parity = "e"
+ base_stopbits = 2
+ cwd = None
+
+ class dummy_transport(BaseTransport):
+ """Transport class for test."""
+
+ def __init__(self):
+ """Initialize."""
+ super().__init__(
+ TestBaseTransport.base_comm_name,
+ [
+ TestBaseTransport.base_reconnect_delay * 1000,
+ TestBaseTransport.base_reconnect_delay_max * 1000,
+ ],
+ TestBaseTransport.base_timeout_connect * 1000,
+ TestBaseTransport.base_framer,
+ None,
+ None,
+ None,
+ )
+ self.abort = mock.MagicMock()
+ self.close = mock.MagicMock()
+
+ @classmethod
+ async def setup_BaseTransport(cls):
+ """Create base object."""
+ base = BaseTransport(
+ cls.base_comm_name,
+ (cls.base_reconnect_delay * 1000, cls.base_reconnect_delay_max * 1000),
+ cls.base_timeout_connect * 1000,
+ cls.base_framer,
+ mock.MagicMock(),
+ mock.MagicMock(),
+ mock.MagicMock(),
+ )
+ params = base.CommParamsClass(
+ done=True,
+ comm_name=cls.base_comm_name,
+ reconnect_delay=cls.base_reconnect_delay,
+ reconnect_delay_max=cls.base_reconnect_delay_max,
+ timeout_connect=cls.base_timeout_connect,
+ framer=cls.base_framer,
+ )
+ cls.cwd = os.getcwd().split("/")[-1]
+ if cls.cwd == "transport":
+ cls.cwd = "../../"
+ elif cls.cwd == "test":
+ cls.cwd = "../"
+ else:
+ cls.cwd = ""
+ cls.cwd = cls.cwd + "examples/certificates/pymodbus."
+ return base, params
+
+ async def test_init(self):
+ """Test init()"""
+ base, params = await self.setup_BaseTransport()
+ params.done = False
+ assert base.comm_params == params
+
+ assert base.cb_connection_made
+ assert base.cb_connection_lost
+ assert base.cb_handle_data
+ assert not base.reconnect_delay_current
+ assert not base.reconnect_timer
+
+ async def test_property_done(self):
+ """Test done property"""
+ base, params = await self.setup_BaseTransport()
+ base.comm_params.check_done()
+ with pytest.raises(RuntimeError):
+ base.comm_params.check_done()
+
+ @pytest.mark.skipif(
+ pytest.IS_WINDOWS, reason="Windows do not support unix sockets."
+ )
+ @pytest.mark.parametrize("setup_server", [True, False])
+ async def test_properties_unix(self, setup_server):
+ """Test properties."""
+ base, params = await self.setup_BaseTransport()
+ base.setup_unix(setup_server, self.base_host)
+ params.host = self.base_host
+ assert base.comm_params == params
+ assert base.call_connect_listen
+
+ @pytest.mark.skipif(
+ not pytest.IS_WINDOWS, reason="Windows do not support unix sockets."
+ )
+ @pytest.mark.parametrize("setup_server", [True, False])
+ async def test_properties_unix_windows(self, setup_server):
+ """Test properties."""
+ base, params = await self.setup_BaseTransport()
+ with pytest.raises(RuntimeError):
+ base.setup_unix(setup_server, self.base_host)
+
+ @pytest.mark.parametrize("setup_server", [True, False])
+ async def test_properties_tcp(self, setup_server):
+ """Test properties."""
+ base, params = await self.setup_BaseTransport()
+ base.setup_tcp(setup_server, self.base_host, self.base_port)
+ params.host = self.base_host
+ params.port = self.base_port
+ assert base.comm_params == params
+ assert base.call_connect_listen
+
+ @pytest.mark.parametrize("setup_server", [True, False])
+ async def test_properties_udp(self, setup_server):
+ """Test properties."""
+ base, params = await self.setup_BaseTransport()
+ base.setup_udp(setup_server, self.base_host, self.base_port)
+ params.host = self.base_host
+ params.port = self.base_port
+ assert base.comm_params == params
+ assert base.call_connect_listen
+
+ @pytest.mark.parametrize("setup_server", [True, False])
+ @pytest.mark.parametrize("sslctx", [None, "test ctx"])
+ async def test_properties_tls(self, setup_server, sslctx):
+ """Test properties."""
+ base, params = await self.setup_BaseTransport()
+ with mock.patch("pymodbus.transport.transport.ssl.SSLContext"):
+ base.setup_tls(
+ setup_server,
+ self.base_host,
+ self.base_port,
+ sslctx,
+ None,
+ None,
+ None,
+ self.base_server_hostname,
+ )
+ params.host = self.base_host
+ params.port = self.base_port
+ params.server_hostname = self.base_server_hostname
+ params.ssl = sslctx if sslctx else base.comm_params.ssl
+ assert base.comm_params == params
+ assert base.call_connect_listen
+
+ @pytest.mark.parametrize("setup_server", [True, False])
+ async def test_properties_serial(self, setup_server):
+ """Test properties."""
+ base, params = await self.setup_BaseTransport()
+ base.setup_serial(
+ setup_server,
+ self.base_host,
+ self.base_baudrate,
+ self.base_bytesize,
+ self.base_parity,
+ self.base_stopbits,
+ )
+ params.host = self.base_host
+ params.baudrate = self.base_baudrate
+ params.bytesize = self.base_bytesize
+ params.parity = self.base_parity
+ params.stopbits = self.base_stopbits
+ assert base.comm_params == params
+ assert base.call_connect_listen
+
+ async def test_with_magic(self):
+ """Test magic."""
+ base, _params = await self.setup_BaseTransport()
+ base.close = mock.MagicMock()
+ async with base:
+ pass
+ base.close.assert_called_once()
+
+ async def test_str_magic(self):
+ """Test magic."""
+ base, _params = await self.setup_BaseTransport()
+ assert str(base) == f"BaseTransport({self.base_comm_name})"
+
+ async def test_connection_made(self):
+ """Test connection_made()."""
+ base, params = await self.setup_BaseTransport()
+ transport = self.dummy_transport()
+ base.connection_made(transport)
+ assert base.transport == transport
+ assert not base.recv_buffer
+ assert not base.reconnect_timer
+ assert base.reconnect_delay_current == params.reconnect_delay
+ base.cb_connection_made.assert_called_once()
+ base.cb_connection_lost.assert_not_called()
+ base.cb_handle_data.assert_not_called()
+ base.close()
+
+ async def test_connection_lost(self):
+ """Test connection_lost()."""
+ base, params = await self.setup_BaseTransport()
+ transport = self.dummy_transport()
+ base.connection_lost(transport)
+ assert not base.transport
+ assert not base.recv_buffer
+ assert not base.reconnect_timer
+ assert not base.reconnect_delay_current
+ base.cb_connection_made.assert_not_called()
+ base.cb_handle_data.assert_not_called()
+ base.cb_connection_lost.assert_called_once()
+ # reconnect is only after a successful connect
+ base.connection_made(transport)
+ base.connection_lost(transport)
+ assert base.reconnect_timer
+ assert not base.transport
+ assert not base.recv_buffer
+ assert base.reconnect_timer
+ assert base.reconnect_delay_current == 2 * params.reconnect_delay
+ base.cb_connection_lost.call_count == 2
+ base.close()
+ assert not base.reconnect_timer
+
+ async def test_eof_received(self):
+ """Test connection_lost()."""
+ base, params = await self.setup_BaseTransport()
+ self.dummy_transport()
+ base.eof_received()
+ assert not base.transport
+ assert not base.recv_buffer
+ assert not base.reconnect_timer
+ assert not base.reconnect_delay_current
+
+ async def test_close(self):
+ """Test close()."""
+ base, _params = await self.setup_BaseTransport()
+ transport = self.dummy_transport()
+ base.connection_made(transport)
+ base.cb_connection_made.reset_mock()
+ base.cb_connection_lost.reset_mock()
+ base.cb_handle_data.reset_mock()
+ base.recv_buffer = b"abc"
+ base.reconnect_timer = mock.MagicMock()
+ base.close()
+ transport.abort.assert_called_once()
+ transport.close.assert_called_once()
+ base.cb_connection_made.assert_not_called()
+ base.cb_connection_lost.assert_not_called()
+ base.cb_handle_data.assert_not_called()
+ assert not base.recv_buffer
+ assert not base.reconnect_timer
+
+ async def test_reset_delay(self):
+ """Test reset_delay()."""
+ base, _params = await self.setup_BaseTransport()
+ base.reconnect_delay_current = self.base_reconnect_delay + 1
+ base.reset_delay()
+ assert base.reconnect_delay_current == self.base_reconnect_delay
+
+ async def test_datagram(self):
+ """Test datagram_received()."""
+ base, _params = await self.setup_BaseTransport()
+ base.data_received = mock.MagicMock()
+ base.datagram_received(b"abc", "127.0.0.1")
+ base.data_received.assert_called_once()
+
+ async def test_data(self):
+ """Test data_received."""
+ base, _params = await self.setup_BaseTransport()
+ base.cb_handle_data = mock.MagicMock(return_value=2)
+ base.data_received(b"123456")
+ base.cb_handle_data.assert_called_once()
+ assert base.recv_buffer == b"3456"
+ base.data_received(b"789")
+ assert base.recv_buffer == b"56789"
+
+ async def test_send(self):
+ """Test send()."""
+ base, _params = await self.setup_BaseTransport()
+ base.transport = mock.AsyncMock()
+ await base.send(b"abc")
+
+ @pytest.mark.skipif(
+ pytest.IS_WINDOWS, reason="Windows do not support unix sockets."
+ )
+ async def test_connect_unix(self):
+ """Test connect_unix()."""
+ base, _params = await self.setup_BaseTransport()
+ base.setup_unix(False, self.base_host)
+ base.close = mock.Mock()
+ mocker = mock.AsyncMock()
+
+ base.loop.create_unix_connection = mocker
+ mocker.side_effect = FileNotFoundError("testing")
+ assert await base.transport_connect() == (None, None)
+ base.close.assert_called_once()
+ mocker.side_effect = None
+
+ mocker.return_value = (117, 118)
+ assert mocker.return_value == await base.transport_connect()
+ base.close.called_once()
+
+ async def test_connect_tcp(self):
+ """Test connect_tcp()."""
+ base, _params = await self.setup_BaseTransport()
+ base.setup_tcp(False, self.base_host, self.base_port)
+ base.close = mock.Mock()
+ mocker = mock.AsyncMock()
+
+ base.loop.create_connection = mocker
+ mocker.side_effect = asyncio.TimeoutError("testing")
+ assert await base.transport_connect() == (None, None)
+ base.close.assert_called_once()
+ mocker.side_effect = None
+
+ mocker.return_value = (117, 118)
+ assert mocker.return_value == await base.transport_connect()
+ base.close.assert_called_once()
+
+ async def test_connect_tls(self):
+ """Test connect_tcls()."""
+ base, _params = await self.setup_BaseTransport()
+ base.setup_tls(
+ False,
+ self.base_host,
+ self.base_port,
+ "no ssl",
+ None,
+ None,
+ None,
+ self.base_server_hostname,
+ )
+ base.close = mock.Mock()
+ mocker = mock.AsyncMock()
+
+ base.loop.create_connection = mocker
+ mocker.side_effect = asyncio.TimeoutError("testing")
+ assert await base.transport_connect() == (None, None)
+ base.close.assert_called_once()
+ mocker.side_effect = None
+
+ mocker.return_value = (117, 118)
+ assert mocker.return_value == await base.transport_connect()
+ base.close.assert_called_once()
+
+ async def test_connect_udp(self):
+ """Test connect_udp()."""
+ base, _params = await self.setup_BaseTransport()
+ base.setup_udp(False, self.base_host, self.base_port)
+ base.close = mock.Mock()
+ mocker = mock.AsyncMock()
+
+ base.loop.create_datagram_endpoint = mocker
+ mocker.side_effect = asyncio.TimeoutError("testing")
+ assert await base.transport_connect() == (None, None)
+ base.close.assert_called_once()
+ mocker.side_effect = None
+
+ mocker.return_value = (117, 118)
+ assert mocker.return_value == await base.transport_connect()
+ base.close.assert_called_once()
+
+ async def test_connect_serial(self):
+ """Test connect_serial()."""
+ base, _params = await self.setup_BaseTransport()
+ base.setup_serial(
+ False,
+ self.base_host,
+ self.base_baudrate,
+ self.base_bytesize,
+ self.base_parity,
+ self.base_stopbits,
+ )
+ base.close = mock.Mock()
+ mocker = mock.AsyncMock()
+
+ with mock.patch(
+ "pymodbus.transport.transport.create_serial_connection", new=mocker
+ ):
+ mocker.side_effect = asyncio.TimeoutError("testing")
+ assert await base.transport_connect() == (None, None)
+ base.close.assert_called_once()
+ mocker.side_effect = None
+
+ mocker.return_value = (117, 118)
+ assert mocker.return_value == await base.transport_connect()
+ base.close.assert_called_once()
+
+ @pytest.mark.skipif(
+ pytest.IS_WINDOWS, reason="Windows do not support unix sockets."
+ )
+ async def test_listen_unix(self):
+ """Test listen_unix()."""
+ base, _params = await self.setup_BaseTransport()
+ base.setup_unix(True, self.base_host)
+ base.close = mock.Mock()
+ mocker = mock.AsyncMock()
+
+ base.loop.create_unix_server = mocker
+ mocker.side_effect = OSError("testing")
+ assert await base.transport_listen() is None
+ base.close.assert_called_once()
+ mocker.side_effect = None
+
+ mocker.return_value = 117
+ assert mocker.return_value == await base.transport_listen()
+ base.close.assert_called_once()
+
+ async def test_listen_tcp(self):
+ """Test listen_tcp()."""
+ base, _params = await self.setup_BaseTransport()
+ base.setup_tcp(True, self.base_host, self.base_port)
+ base.close = mock.Mock()
+ mocker = mock.AsyncMock()
+
+ base.loop.create_server = mocker
+ mocker.side_effect = OSError("testing")
+ assert await base.transport_listen() is None
+ base.close.assert_called_once()
+ mocker.side_effect = None
+
+ mocker.return_value = 117
+ assert mocker.return_value == await base.transport_listen()
+ base.close.assert_called_once()
+
+ async def test_listen_tls(self):
+ """Test listen_tls()."""
+ base, _params = await self.setup_BaseTransport()
+ base.setup_tls(
+ True,
+ self.base_host,
+ self.base_port,
+ "no ssl",
+ None,
+ None,
+ None,
+ self.base_server_hostname,
+ )
+ base.close = mock.Mock()
+ mocker = mock.AsyncMock()
+
+ base.loop.create_server = mocker
+ mocker.side_effect = OSError("testing")
+ assert await base.transport_listen() is None
+ base.close.assert_called_once()
+ mocker.side_effect = None
+
+ mocker.return_value = 117
+ assert mocker.return_value == await base.transport_listen()
+ base.close.assert_called_once()
+
+ async def test_listen_udp(self):
+ """Test listen_udp()."""
+ base, _params = await self.setup_BaseTransport()
+ base.setup_udp(True, self.base_host, self.base_port)
+ base.close = mock.Mock()
+ mocker = mock.AsyncMock()
+
+ base.loop.create_datagram_endpoint = mocker
+ mocker.side_effect = OSError("testing")
+ assert await base.transport_listen() is None
+ base.close.assert_called_once()
+ mocker.side_effect = None
+
+ mocker.return_value = (117, 118)
+ assert await base.transport_listen() == 117
+ base.close.assert_called_once()
+
+ async def test_listen_serial(self):
+ """Test listen_serial()."""
+ base, _params = await self.setup_BaseTransport()
+ base.setup_serial(
+ True,
+ self.base_host,
+ self.base_baudrate,
+ self.base_bytesize,
+ self.base_parity,
+ self.base_stopbits,
+ )
+ base.close = mock.Mock()
+ mocker = mock.AsyncMock()
+
+ with mock.patch(
+ "pymodbus.transport.transport.create_serial_connection", new=mocker
+ ):
+ mocker.side_effect = SerialException("testing")
+ assert await base.transport_listen() is None
+ base.close.assert_called_once()
+ mocker.side_effect = None
+
+ mocker.return_value = 117
+ assert await base.transport_listen() == 117
+ base.close.assert_called_once()
diff --git a/test/transport/test_comm.py b/test/transport/test_comm.py
new file mode 100644
index 000000000..e810fa332
--- /dev/null
+++ b/test/transport/test_comm.py
@@ -0,0 +1,382 @@
+"""Test transport."""
+import asyncio
+import os
+import sys
+import time
+from tempfile import gettempdir
+
+import pytest
+
+from pymodbus.framer import ModbusFramer, ModbusSocketFramer
+from pymodbus.transport.transport import BaseTransport
+
+
+class TestCommTransport:
+ """Test for the transport module."""
+
+ cwd = None
+
+ @classmethod
+ def setup_CWD(cls):
+ """Get path to certificates."""
+ cls.cwd = os.getcwd().split("/")[-1]
+ if cls.cwd == "transport":
+ cls.cwd = "../../"
+ elif cls.cwd == "test":
+ cls.cwd = "../"
+ else:
+ cls.cwd = ""
+ cls.cwd = cls.cwd + "examples/certificates/pymodbus."
+
+ class dummy_transport(BaseTransport):
+ """Transport class for test."""
+
+ def cb_connection_made(self):
+ """Handle callback."""
+
+ def cb_connection_lost(self, _exc):
+ """Handle callback."""
+
+ def cb_handle_data(self, _data):
+ """Handle callback."""
+ return 0
+
+ def __init__(self, framer: ModbusFramer, comm_name="test comm"):
+ """Initialize."""
+ super().__init__(
+ comm_name,
+ [2500, 9000],
+ 2000,
+ framer,
+ self.cb_connection_made,
+ self.cb_connection_lost,
+ self.cb_handle_data,
+ )
+
+ @pytest.mark.skipif(
+ pytest.IS_WINDOWS, reason="Windows do not support unix sockets."
+ )
+ @pytest.mark.xdist_group(name="server_serialize")
+ async def test_connect_unix(self):
+ """Test connect_unix()."""
+ client = self.dummy_transport(ModbusSocketFramer)
+ domain_socket = "/domain_unix"
+ client.setup_unix(False, domain_socket)
+ start = time.time()
+ assert await client.transport_connect() == (None, None)
+ delta = time.time() - start
+ assert delta < client.comm_params.timeout_connect * 1.2
+
+ client = self.dummy_transport(ModbusSocketFramer)
+ domain_socket = gettempdir() + "/domain_unix"
+ client.setup_unix(False, domain_socket)
+ start = time.time()
+ assert await client.transport_connect() == (None, None)
+ delta = time.time() - start
+ assert delta < client.comm_params.timeout_connect * 1.2
+
+ @pytest.mark.xdist_group(name="server_serialize")
+ async def test_connect_tcp(self):
+ """Test connect_tcp()."""
+ client = self.dummy_transport(ModbusSocketFramer)
+ client.setup_tcp(False, "142.250.200.78", 502)
+ start = time.time()
+ assert await client.transport_connect() == (None, None)
+ delta = time.time() - start
+ assert delta < client.comm_params.timeout_connect * 1.2
+
+ client = self.dummy_transport(ModbusSocketFramer)
+ client.setup_tcp(False, "localhost", 5001)
+ start = time.time()
+ assert await client.transport_connect() == (None, None)
+ delta = time.time() - start
+ assert delta < client.comm_params.timeout_connect * 1.2
+
+ @pytest.mark.xdist_group(name="server_serialize")
+ async def test_connect_tls(self):
+ """Test connect_tls()."""
+ self.setup_CWD()
+ client = self.dummy_transport(ModbusSocketFramer)
+ client.setup_tls(
+ False,
+ "142.250.200.78",
+ 502,
+ None,
+ self.cwd + "crt",
+ self.cwd + "key",
+ None,
+ "localhost",
+ )
+ start = time.time()
+ assert await client.transport_connect() == (None, None)
+ delta = time.time() - start
+ assert delta < client.comm_params.timeout_connect * 1.2
+
+ client = self.dummy_transport(ModbusSocketFramer)
+ client.setup_tls(
+ False,
+ "127.0.0.1",
+ 5001,
+ None,
+ self.cwd + "crt",
+ self.cwd + "key",
+ None,
+ "localhost",
+ )
+ start = time.time()
+ assert await client.transport_connect() == (None, None)
+ delta = time.time() - start
+ assert delta < client.comm_params.timeout_connect * 1.2
+
+ @pytest.mark.xdist_group(name="server_serialize")
+ async def test_connect_serial(self):
+ """Test connect_serial()."""
+ client = self.dummy_transport(ModbusSocketFramer)
+ client.setup_serial(
+ False,
+ "no_port",
+ 9600,
+ 8,
+ "E",
+ 2,
+ )
+ start = time.time()
+ assert await client.transport_connect() == (None, None)
+ delta = time.time() - start
+ assert delta < client.comm_params.timeout_connect * 1.2
+
+ client = self.dummy_transport(ModbusSocketFramer)
+ client.setup_serial(
+ False,
+ "unix:/localhost:5001",
+ 9600,
+ 8,
+ "E",
+ 2,
+ )
+ start = time.time()
+ assert await client.transport_connect() == (None, None)
+ delta = time.time() - start
+ assert delta < client.comm_params.timeout_connect * 1.2
+
+ @pytest.mark.skipif(
+ pytest.IS_WINDOWS, reason="Windows do not support unix sockets."
+ )
+ @pytest.mark.xdist_group(name="server_serialize")
+ async def test_listen_unix(self):
+ """Test listen_unix()."""
+ server = self.dummy_transport(ModbusSocketFramer)
+ domain_socket = "/test_unix_"
+ server.setup_unix(True, domain_socket)
+ assert not await server.transport_listen()
+ assert not server.transport
+
+ server = self.dummy_transport(ModbusSocketFramer)
+ domain_socket = gettempdir() + "/test_unix_" + str(time.time())
+ server.setup_unix(True, domain_socket)
+ assert await server.transport_listen()
+ assert server.transport
+ server.close()
+
+ @pytest.mark.xdist_group(name="server_serialize")
+ async def test_listen_tcp(self):
+ """Test listen_tcp()."""
+ server = self.dummy_transport(ModbusSocketFramer)
+ server.setup_tcp(True, "10.0.0.1", 5101)
+ assert not await server.transport_listen()
+ assert not server.transport
+
+ server = self.dummy_transport(ModbusSocketFramer)
+ server.setup_tcp(True, "localhost", 5101)
+ assert await server.transport_listen()
+ assert server.transport
+ server.close()
+
+ @pytest.mark.xdist_group(name="server_serialize")
+ async def test_listen_tls(self):
+ """Test listen_tls()."""
+ self.setup_CWD()
+ server = self.dummy_transport(ModbusSocketFramer)
+ server.setup_tls(
+ True,
+ "10.0.0.1",
+ 5101,
+ None,
+ self.cwd + "crt",
+ self.cwd + "key",
+ None,
+ "localhost",
+ )
+ assert not await server.transport_listen()
+ assert not server.transport
+
+ server = self.dummy_transport(ModbusSocketFramer)
+ server.setup_tls(
+ True,
+ "127.0.0.1",
+ 5101,
+ None,
+ self.cwd + "crt",
+ self.cwd + "key",
+ None,
+ "localhost",
+ )
+ assert await server.transport_listen()
+ assert server.transport
+ server.close()
+
+ @pytest.mark.xdist_group(name="server_serialize")
+ async def test_listen_udp(self):
+ """Test listen_udp()."""
+ server = self.dummy_transport(ModbusSocketFramer)
+ server.setup_udp(True, "10.0.0.1", 5101)
+ assert not await server.transport_listen()
+ assert not server.transport
+
+ server = self.dummy_transport(ModbusSocketFramer)
+ server.setup_udp(True, "localhost", 5101)
+ assert await server.transport_listen()
+ assert server.transport
+ server.close()
+
+ @pytest.mark.xdist_group(name="server_serialize")
+ async def test_listen_serial(self):
+ """Test listen_serial()."""
+ server = self.dummy_transport(ModbusSocketFramer)
+ server.setup_serial(
+ True,
+ "no port",
+ 9600,
+ 8,
+ "E",
+ 2,
+ )
+ assert not await server.transport_listen()
+ assert not server.transport
+
+ # there are no positive test, since there are no standard tty port
+
+ @pytest.mark.skipif(
+ pytest.IS_WINDOWS, reason="Windows do not support unix sockets."
+ )
+ @pytest.mark.xdist_group(name="server_serialize")
+ async def test_connected_unix(self):
+ """Test listen/connect unix()."""
+ server_protocol = self.dummy_transport(ModbusSocketFramer)
+ domain_socket = gettempdir() + "/test_unix_" + str(time.time())
+ server_protocol.setup_unix(True, domain_socket)
+ await server_protocol.transport_listen()
+
+ client = self.dummy_transport(ModbusSocketFramer)
+ client.setup_unix(False, domain_socket)
+ assert await client.transport_connect() != (None, None)
+ client.close()
+ server_protocol.close()
+
+ @pytest.mark.xdist_group(name="server_serialize")
+ async def test_connected_tcp(self):
+ """Test listen/connect tcp()."""
+ server_protocol = self.dummy_transport(ModbusSocketFramer)
+ server_protocol.setup_tcp(True, "localhost", 5101)
+ assert await server_protocol.transport_listen()
+
+ client = self.dummy_transport(ModbusSocketFramer)
+ client.setup_tcp(False, "localhost", 5101)
+ assert await client.transport_connect() != (None, None)
+ client.close()
+ server_protocol.close()
+
+ @pytest.mark.xdist_group(name="server_serialize")
+ async def test_connected_tls(self):
+ """Test listen/connect tls()."""
+ self.setup_CWD()
+ server_protocol = self.dummy_transport(ModbusSocketFramer)
+ server_protocol.setup_tls(
+ True,
+ "127.0.0.1",
+ 5102,
+ None,
+ self.cwd + "crt",
+ self.cwd + "key",
+ None,
+ "localhost",
+ )
+ assert await server_protocol.transport_listen()
+
+ client = self.dummy_transport(ModbusSocketFramer)
+ client.setup_tls(
+ False,
+ "127.0.0.1",
+ 5102,
+ None,
+ self.cwd + "crt",
+ self.cwd + "key",
+ None,
+ "localhost",
+ )
+ assert await client.transport_connect() != (None, None)
+ client.close()
+ server_protocol.close()
+
+ @pytest.mark.xdist_group(name="server_serialize")
+ async def test_connected_udp(self):
+ """Test listen/connect udp()."""
+ server_protocol = self.dummy_transport(ModbusSocketFramer)
+ server_protocol.setup_udp(True, "localhost", 5101)
+ transport = await server_protocol.transport_listen()
+ assert transport
+
+ client = self.dummy_transport(ModbusSocketFramer)
+ client.setup_udp(False, "localhost", 5101)
+ assert await client.transport_connect() != (None, None)
+ client.close()
+ server_protocol.close()
+
+ @pytest.mark.xdist_group(name="server_serialize")
+ async def test_connected_serial(self):
+ """Test listen/connect serial()."""
+ server_protocol = self.dummy_transport(ModbusSocketFramer)
+ server_protocol.setup_tcp(True, "localhost", 5101)
+ assert await server_protocol.transport_listen()
+
+ client = self.dummy_transport(ModbusSocketFramer)
+ client.setup_serial(
+ False,
+ "unix:localhost:5001",
+ 9600,
+ 8,
+ "E",
+ 2,
+ )
+ assert await client.transport_connect() == (None, None)
+ client.close()
+ server_protocol.close()
+
+ @pytest.mark.xdist_group(name="server_serialize")
+ async def test_connect_reconnect(self):
+ """Test connect() reconnecting."""
+ server = self.dummy_transport(ModbusSocketFramer, comm_name="server mode")
+ server.setup_tcp(True, "localhost", 5101)
+ await server.transport_listen()
+ assert server.transport
+
+ client = self.dummy_transport(ModbusSocketFramer, comm_name="client mode")
+ client.setup_tcp(False, "localhost", 5101)
+ assert await client.transport_connect() != (None, None)
+ server.close()
+ count = 100
+ while client.transport and count:
+ await asyncio.sleep(0.1)
+ count -= 1
+ if not sys.platform.startswith("win"):
+ assert not client.transport
+ assert client.reconnect_timer
+ assert (
+ client.reconnect_delay_current == 2 * client.comm_params.reconnect_delay
+ )
+ await asyncio.sleep(client.reconnect_delay_current * 1.2)
+ assert client.transport
+ assert client.reconnect_timer
+ assert client.reconnect_delay_current == client.comm_params.reconnect_delay
+ client.close()
+ server.close()
diff --git a/test/transport/test_data.py b/test/transport/test_data.py
new file mode 100644
index 000000000..035e3afb9
--- /dev/null
+++ b/test/transport/test_data.py
@@ -0,0 +1,55 @@
+"""Test transport."""
+import asyncio
+
+import pytest
+
+from pymodbus.framer import ModbusFramer, ModbusSocketFramer
+from pymodbus.transport.transport import BaseTransport
+
+
+class TestDataTransport:
+ """Test for the transport module."""
+
+ class dummy_transport(BaseTransport):
+ """Transport class for test."""
+
+ def cb_connection_made(self):
+ """Handle callback."""
+
+ def cb_connection_lost(self, _exc):
+ """Handle callback."""
+
+ def cb_handle_data(self, _data):
+ """Handle callback."""
+ return 0
+
+ def __init__(self, framer: ModbusFramer, comm_name="test comm"):
+ """Initialize."""
+ super().__init__(
+ comm_name,
+ [2500, 9000],
+ 2000,
+ framer,
+ self.cb_connection_made,
+ self.cb_connection_lost,
+ self.cb_handle_data,
+ )
+
+ @pytest.mark.skipif(pytest.IS_WINDOWS, reason="Windows problem.")
+ @pytest.mark.xdist_group(name="server_serialize")
+ async def test_client_send(self):
+ """Test connect() reconnecting."""
+ server = self.dummy_transport(ModbusSocketFramer, comm_name="server mode")
+ server.setup_tcp(True, "localhost", 5101)
+ await server.transport_listen()
+ assert server.transport
+
+ client = self.dummy_transport(ModbusSocketFramer, comm_name="client mode")
+ client.setup_tcp(False, "localhost", 5101)
+ assert await client.transport_connect() != (None, None)
+ await client.send(b"ABC")
+ await asyncio.sleep(2)
+ assert server.recv_buffer == b"ABC"
+ await server.send(b"DEF")
+ await asyncio.sleep(2)
+ assert client.recv_buffer == b"DEF"