Skip to content

Commit

Permalink
Combine conftest to a central file (#1962)
Browse files Browse the repository at this point in the history
  • Loading branch information
janiversen authored Feb 3, 2024
1 parent 0a9d23e commit dedcb8b
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 229 deletions.
206 changes: 195 additions & 11 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,29 @@
"""Configure pytest."""
import asyncio
import os
import platform
import sys
from collections import deque
from threading import enumerate as thread_enumerate
from unittest import mock

import pytest
import pytest_asyncio

from pymodbus.datastore import ModbusBaseSlaveContext
from pymodbus.server import ServerAsyncStop
from pymodbus.transport import NULLMODEM_HOST, CommParams, CommType, ModbusProtocol
from pymodbus.transport.transport import NullModem


sys.path.extend(["examples", "../examples", "../../examples"])

from examples.server_async import ( # noqa: E402 # pylint: disable=wrong-import-position
run_async_server,
setup_server,
)


def pytest_configure():
"""Configure pytest."""
pytest.IS_DARWIN = platform.system().lower() == "darwin"
Expand Down Expand Up @@ -42,6 +55,188 @@ def get_base_ports():
return BASE_PORTS


@pytest.fixture(name="use_comm_type")
def prepare_dummy_use_comm_type():
"""Return default comm_type."""
return CommType.TCP


@pytest.fixture(name="use_host")
def define_use_host():
"""Set default host."""
return NULLMODEM_HOST


@pytest.fixture(name="use_cls")
def prepare_commparams_server(use_port, use_host, use_comm_type):
"""Prepare CommParamsClass object."""
if use_host == NULLMODEM_HOST and use_comm_type == CommType.SERIAL:
use_host = f"{NULLMODEM_HOST}:{use_port}"
return CommParams(
comm_name="test comm",
comm_type=use_comm_type,
reconnect_delay=0,
reconnect_delay_max=0,
timeout_connect=0,
source_address=(use_host, use_port),
baudrate=9600,
bytesize=8,
parity="E",
stopbits=2,
)


@pytest.fixture(name="use_clc")
def prepare_commparams_client(use_port, use_host, use_comm_type):
"""Prepare CommParamsClass object."""
if use_host == NULLMODEM_HOST and use_comm_type == CommType.SERIAL:
use_host = f"{NULLMODEM_HOST}:{use_port}"
timeout = 10 if not pytest.IS_WINDOWS else 2
return CommParams(
comm_name="test comm",
comm_type=use_comm_type,
reconnect_delay=0.1,
reconnect_delay_max=0.35,
timeout_connect=timeout,
host=use_host,
port=use_port,
baudrate=9600,
bytesize=8,
parity="E",
stopbits=2,
)


@pytest.fixture(name="client")
def prepare_protocol(use_clc):
"""Prepare transport object."""
transport = ModbusProtocol(use_clc, False)
transport.callback_connected = mock.Mock()
transport.callback_disconnected = mock.Mock()
transport.callback_data = mock.Mock(return_value=0)
if use_clc.comm_type == CommType.TLS:
cwd = os.path.dirname(__file__) + "/../examples/certificates/pymodbus."
transport.comm_params.sslctx = use_clc.generate_ssl(
False, certfile=cwd + "crt", keyfile=cwd + "key"
)
if use_clc.comm_type == CommType.SERIAL:
transport.comm_params.host = f"socket://localhost:{transport.comm_params.port}"
return transport


@pytest.fixture(name="server")
def prepare_transport_server(use_cls):
"""Prepare transport object."""
transport = ModbusProtocol(use_cls, True)
transport.callback_connected = mock.Mock()
transport.callback_disconnected = mock.Mock()
transport.callback_data = mock.Mock(return_value=0)
if use_cls.comm_type == CommType.TLS:
cwd = os.path.dirname(__file__) + "/../examples/certificates/pymodbus."
transport.comm_params.sslctx = use_cls.generate_ssl(
True, certfile=cwd + "crt", keyfile=cwd + "key"
)
return transport


class DummyProtocol(ModbusProtocol):
"""Use in connection_made calls."""

def __init__(self, is_server=False): # pylint: disable=super-init-not-called
"""Initialize."""
self.comm_params = CommParams()
self.transport = None
self.is_server = is_server
self.is_closing = False
self.data = b""
self.connection_made = mock.Mock()
self.connection_lost = mock.Mock()
self.reconnect_task: asyncio.Task = None

def handle_new_connection(self):
"""Handle incoming connect."""
if not self.is_server:
# Clients reuse the same object.
return self
return DummyProtocol()

def close(self):
"""Simulate close."""
self.is_closing = True

def data_received(self, data):
"""Call when some data is received."""
self.data += data


@pytest.fixture(name="dummy_protocol")
def prepare_dummy_protocol():
"""Return transport object."""
return DummyProtocol


@pytest.fixture(name="mock_clc")
def define_commandline_client(
use_comm,
use_framer,
use_port,
use_host,
):
"""Define commandline."""
my_port = str(use_port)
cmdline = ["--comm", use_comm, "--framer", use_framer, "--timeout", "0.1"]
if use_comm == "serial":
if use_host == NULLMODEM_HOST:
use_host = f"{use_host}:{my_port}"
else:
use_host = f"socket://{use_host}:{my_port}"
cmdline.extend(["--baudrate", "9600", "--port", use_host])
else:
cmdline.extend(["--port", my_port, "--host", use_host])
return cmdline


@pytest.fixture(name="mock_cls")
def define_commandline_server(
use_comm,
use_framer,
use_port,
use_host,
):
"""Define commandline."""
my_port = str(use_port)
cmdline = [
"--comm",
use_comm,
"--framer",
use_framer,
]
if use_comm == "serial":
if use_host == NULLMODEM_HOST:
use_host = f"{use_host}:{my_port}"
else:
use_host = f"socket://{use_host}:{my_port}"
cmdline.extend(["--baudrate", "9600", "--port", use_host])
else:
cmdline.extend(["--port", my_port, "--host", use_host])
return cmdline


@pytest_asyncio.fixture(name="mock_server")
async def _run_server(
mock_cls,
):
"""Run server."""
run_args = setup_server(cmdline=mock_cls)
task = asyncio.create_task(run_async_server(run_args))
task.set_name("mock_server")
await asyncio.sleep(0.1)
yield mock_cls
await ServerAsyncStop()
task.cancel()
await task


@pytest.fixture(name="system_health_check", autouse=True)
async def _check_system_health():
"""Check Thread, asyncio.task and NullModem for leftovers."""
Expand Down Expand Up @@ -193,14 +388,3 @@ def sendto(self, msg, *_args):
def setblocking(self, _flag):
"""Set blocking."""
return None


_CURRENT_PORT = 5200


@pytest.fixture(name="use_port")
def get_port():
"""Get next port."""
global _CURRENT_PORT # pylint: disable=global-statement
_CURRENT_PORT += 1
return _CURRENT_PORT
85 changes: 0 additions & 85 deletions test/sub_examples/conftest.py

This file was deleted.

Loading

0 comments on commit dedcb8b

Please sign in to comment.