Skip to content

Commit

Permalink
Fix up xdist concurrent handling logic
Browse files Browse the repository at this point in the history
- Large refactor of plugin code to better isolate per session vs per
worker, and fix test collection logic
  • Loading branch information
joshuatz committed Sep 15, 2024
1 parent e30651b commit fb29cbf
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 39 deletions.
6 changes: 5 additions & 1 deletion Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ tasks:
- task: lint:types
test:
deps: [_verify_python_venv, install]
cmd: poetry run pytest -n auto
cmd: |
# Clear cache files
rm -rf $PACKAGE_DIR/testing/.pytest_run_cache
# Run pytest
poetry run pytest -n auto
#============================================================#
#================= SECTION_HEADING ==========================#
#============================================================#
170 changes: 133 additions & 37 deletions django_utils_lib/testing/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,33 @@
import csv
import json
import os
import pathlib
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Union,
cast,
)

import pytest
import xdist
import xdist.dsession
import xdist.workermanage
from constants import PACKAGE_NAME
from filelock import FileLock
from typing_extensions import NotRequired, TypedDict

from django_utils_lib.logger import build_heading_block, pkg_logger
from django_utils_lib.testing.utils import PytestNodeID, validate_requirement_tagging
from django_utils_lib.testing.utils import PytestNodeID, is_main_pytest_runner, validate_requirement_tagging

BASE_DIR = Path(__file__).resolve().parent

# Due to the parallelized nature of xdist (which our library consumer might or might
# not be using), we are going to use a file-based system for implementing both
# a concurrency lock, as well as a way to easily share the metadata across
# processes.
temp_file_path = os.path.join(BASE_DIR, "test.temp.json")
temp_file_lock_path = f"{temp_file_path}.lock"
file_lock = FileLock(temp_file_lock_path)


TestStatus = Literal["PASS", "FAIL", ""]

Expand Down Expand Up @@ -103,6 +103,36 @@ class PluginConfigurationItem(TypedDict):
}


class InternalSessionConfig(TypedDict):
    """
    Session-wide values that are generated once on the main runner (see
    `pytest_sessionstart`) and copied into every xdist worker's `workerinput`
    (see `pytest_configure_node`).
    """

    # Random hex string (`uuid.uuid4().hex`) identifying this overall test run
    global_session_id: str
    # Directory under `.pytest_run_cache` that is named after `global_session_id`
    temp_shared_session_dir_path: str


# Note: Redundant typing of InternalSessionConfig, but likely unavoidable
# due to lack of type-coercion features in Python types
@dataclass
class InternalSessionConfigDataClass:
    """
    Attribute-style counterpart of `InternalSessionConfig`.

    Used as a `cast()` target so that the session values, which are stored as
    plain attributes on the pytest config object, can be read and written with
    type checking.
    """

    # Same fields (and meanings) as the keys of `InternalSessionConfig`
    global_session_id: str
    temp_shared_session_dir_path: str


class InternalWorkerConfig(InternalSessionConfig):
    """
    Shape of the `workerinput` mapping on an xdist worker: the shared session
    values, plus xdist's own keys, plus the worker-scoped values injected by
    `pytest_configure_node`.
    """

    # These values are provided by xdist automatically
    workerid: str
    """
    Auto-generated worker ID (`gw0`, `gw1`, etc.)
    """
    workercount: int
    testrunuid: str
    # Our own injected values
    # Worker-scoped directory (`<temp_shared_session_dir_path>/<workerid>`)
    temp_worker_dir_path: str


@dataclass
class WorkerConfigInstance:
    """
    Typing helper describing a pytest config object as seen on an xdist worker.

    In the visible code it is only used as a `cast()` target, to get typed
    access to the `workerinput` attribute.
    """

    # Mapping handed to the worker by the main process
    workerinput: InternalWorkerConfig


class CollectedTestMetadata(TypedDict):
"""
Metadata that is collected for each test "node"
Expand Down Expand Up @@ -138,11 +168,26 @@ class CollectedTests:
File-backed data-store for collected test info
"""

def __init__(self, run_id: str) -> None:
    """
    Prepare the on-disk location that backs this data-store.

    Args:
        run_id: Name of the sub-directory (inside `.pytest_run_cache`) that holds
            the data. Normally the global session ID, so that all workers share
            one store; pass something worker-specific only if results should be
            isolated per worker.
    """
    cache_dir = os.path.join(BASE_DIR, ".pytest_run_cache", run_id)
    self.tmp_dir_path = cache_dir
    os.makedirs(cache_dir, exist_ok=True)

    # xdist (which our library consumer might or might not be using) runs tests
    # in parallel processes, so a file-based system is used for both the
    # concurrency lock and for sharing the metadata across those processes.
    data_file = os.path.join(cache_dir, "test.temp.json")
    self.temp_file_path = data_file
    self.temp_file_lock_path = f"{data_file}.lock"
    self.file_lock = FileLock(self.temp_file_lock_path)

def _get_data(self) -> CollectedTestsMapping:
with file_lock:
if not os.path.exists(temp_file_path):
with self.file_lock:
if not os.path.exists(self.temp_file_path):
return {}
with open(temp_file_path, "r") as f:
with open(self.temp_file_path, "r") as f:
return json.load(f)

def __getitem__(self, node_id: PytestNodeID) -> CollectedTestMetadata:
Expand All @@ -151,21 +196,18 @@ def __getitem__(self, node_id: PytestNodeID) -> CollectedTestMetadata:
def __setitem__(self, node_id: str, item: CollectedTestMetadata):
    """
    Store (or overwrite) the metadata recorded for a single test node.

    Args:
        node_id: Pytest node ID used as the key in the backing JSON file
        item: Metadata to persist for that node
    """
    # Hold the lock for the whole read-modify-write cycle. If it were released
    # between the read and the write, two processes could both load the same
    # snapshot and the later write would silently drop the other's entry.
    # `_get_data` acquires the same lock again; FileLock is re-entrant, so the
    # nested acquisition does not block.
    with self.file_lock:
        updated_data = self._get_data()
        updated_data[node_id] = item
        with open(self.temp_file_path, "w") as f:
            json.dump(updated_data, f)

def update_test_status(self, node_id: PytestNodeID, updated_status: TestStatus):
    """
    Overwrite the `status` field of an already-recorded test node.

    Args:
        node_id: Pytest node ID of the test to update
        updated_status: New status value to persist

    Raises:
        KeyError: If no entry has been stored for `node_id`
    """
    # Hold the lock for the whole read-modify-write cycle. If it were released
    # between the read and the write, a concurrent writer's update could be
    # overwritten with stale data. `_get_data` acquires the same lock again;
    # FileLock is re-entrant, so the nested acquisition does not block.
    with self.file_lock:
        updated_data = self._get_data()
        updated_data[node_id]["status"] = updated_status
        with open(self.temp_file_path, "w") as f:
            json.dump(updated_data, f)


collected_tests = CollectedTests()


@pytest.hookimpl()
def pytest_addoption(parser: pytest.Parser):
# Register all config key-pairs with INI parser
Expand All @@ -175,58 +217,114 @@ def pytest_addoption(parser: pytest.Parser):

@pytest.hookimpl()
def pytest_configure(config: pytest.Config):
if hasattr(config, "workerinput"):
if not is_main_pytest_runner(config):
return

# Register markers
config.addinivalue_line("markers", "requirements(requirements: List[str]): Attach requirements to test")

# Register plugin
plugin = CustomPytestPlugin(config)
config.pluginmanager.register(plugin)

@pytest.hookimpl()
def pytest_sessionstart(session: pytest.Session):
if is_main_pytest_runner(session):
# If we are on the main runner, this is either a non-xdist run, or
# this is the main xdist process, before nodes have been distributed.
# Regardless, we should set up a shared temporary directory, which can
# be shared among all n{0,} nodes
global_session_id = uuid.uuid4().hex
temp_shared_session_dir_path = os.path.join(BASE_DIR, ".pytest_run_cache", global_session_id)
pathlib.Path(temp_shared_session_dir_path).mkdir(parents=True, exist_ok=True)
session_config = cast(InternalSessionConfigDataClass, session.config)
session_config.global_session_id = global_session_id
session_config.temp_shared_session_dir_path = temp_shared_session_dir_path

plugin = CustomPytestPlugin(session.config)
session.config.pluginmanager.register(plugin)
pkg_logger.debug(f"{PACKAGE_NAME} plugin registered")
plugin.auto_engage_debugger()


def pytest_configure_node(node: xdist.workermanage.WorkerController):
    """
    xdist-only hook, invoked while a worker node is being configured (before it
    is instantiated and distributed).

    Runs on the main process only (never on workers), and is not called at all
    when xdist is not in use. It creates a worker-scoped temp directory and
    forwards both the worker-specific and the shared session values to the
    worker through `node.workerinput`.
    """
    worker_input = node.workerinput
    worker_id: str = worker_input["workerid"]

    # Shared values were attached to the config object on the main runner
    main_config = cast(InternalSessionConfigDataClass, node.config)
    shared_dir = main_config.temp_shared_session_dir_path

    # Each worker gets its own sub-directory inside the shared session directory
    worker_dir = os.path.join(shared_dir, worker_id)
    Path(worker_dir).mkdir(parents=True, exist_ok=True)

    # Hand everything over to the worker
    worker_input["temp_worker_dir_path"] = worker_dir
    worker_input["temp_shared_session_dir_path"] = shared_dir
    worker_input["global_session_id"] = main_config.global_session_id


class CustomPytestPlugin:
# Tell Pytest that this is not a test class
__test__ = False

def __init__(self, pytest_config: pytest.Config) -> None:
self.pytest_config = pytest_config
self.collected_tests = CollectedTests(self.get_internal_shared_config(pytest_config)["global_session_id"])
self.debugger_listening = False
# We might or might not be running inside an xdist worker
self._is_running_on_worker = False
self._is_running_on_worker = not is_main_pytest_runner(pytest_config)

def get_config_val(self, config_key: PluginConfigKey):
def get_global_config_val(self, config_key: PluginConfigKey):
"""
Wrapper function just to add some extra type-safety around dynamic config keys
"""
return self.pytest_config.getini(config_key)

def get_internal_shared_config(
    self, pytest_obj: Union[pytest.Session, pytest.Config, pytest.FixtureRequest]
) -> InternalSessionConfig:
    """
    Look up the shared session values, hiding the fact that they are stored in
    different places depending on whether we are on the main runner or a worker.

    Args:
        pytest_obj: A pytest config, or an object exposing one via `.config`

    Returns:
        On the main runner, a new dict with the two shared values read from the
        config object's attributes. On a worker, the config's `workerinput`
        mapping itself (which also carries the worker-specific keys).
    """
    if isinstance(pytest_obj, pytest.Config):
        config = pytest_obj
    else:
        config = pytest_obj.config

    if not is_main_pytest_runner(config):
        # Worker: the main process forwarded the values via `workerinput`
        return cast(WorkerConfigInstance, config).workerinput

    # Main runner: values live directly on the config object
    main_config = cast(InternalSessionConfigDataClass, config)
    shared_dir = main_config.temp_shared_session_dir_path
    session_id = main_config.global_session_id
    return {
        "temp_shared_session_dir_path": shared_dir,
        "global_session_id": session_id,
    }

@property
def auto_debug(self) -> bool:
# Disable if CI is detected
if os.getenv("CI", "").lower() == "true":
return False
return bool(self.get_config_val("auto_debug")) or bool(os.getenv(f"{PACKAGE_NAME}_AUTO_DEBUG", ""))
return bool(self.get_global_config_val("auto_debug")) or bool(os.getenv(f"{PACKAGE_NAME}_AUTO_DEBUG", ""))

@property
def auto_debug_wait_for_connect(self) -> bool:
return bool(self.get_config_val("auto_debug_wait_for_connect"))
return bool(self.get_global_config_val("auto_debug_wait_for_connect"))

@property
def mandate_requirement_markers(self) -> bool:
return bool(self.get_config_val("mandate_requirement_markers"))
return bool(self.get_global_config_val("mandate_requirement_markers"))

@property
def reporting_config(self) -> Optional[PluginReportingConfiguration]:
csv_export_path = self.get_config_val("reporting.csv_export_path")
csv_export_path = self.get_global_config_val("reporting.csv_export_path")
if not isinstance(csv_export_path, str):
return None
return {
"csv_export_path": csv_export_path,
"omit_unexecuted_tests": bool(self.get_config_val("reporting.omit_unexecuted_tests")),
"omit_unexecuted_tests": bool(self.get_global_config_val("reporting.omit_unexecuted_tests")),
}

@property
Expand Down Expand Up @@ -282,7 +380,7 @@ def pytest_collection_modifyitems(self, config: pytest.Config, items: List[pytes
requirements = validation_results["validated_requirements"]

doc_string: str = item.obj.__doc__ or "" # type: ignore
collected_tests[item.nodeid] = {
self.collected_tests[item.nodeid] = {
"node_id": item.nodeid,
"requirements": requirements,
"doc_string": doc_string.strip(),
Expand All @@ -294,10 +392,8 @@ def pytest_collection_modifyitems(self, config: pytest.Config, items: List[pytes

@pytest.hookimpl()
def pytest_sessionstart(self, session: pytest.Session):
self._is_running_on_worker = getattr(session.config, "workerinput", None) is not None

if self._is_running_on_worker:
# Nothing to do here at the moment
if not is_main_pytest_runner(session):
self._is_running_on_worker = True
return

# Init debugpy listener on main
Expand All @@ -311,7 +407,7 @@ def pytest_collection_finish(self, session: pytest.Session):
def pytest_sessionfinish(self, session: pytest.Session, exitstatus):
if not self.reporting_config:
return
collected_test_mappings = collected_tests._get_data()
collected_test_mappings = self.collected_tests._get_data()
with open(self.reporting_config["csv_export_path"], "w") as csv_file:
# Use keys of first entry, since all entries should have same keys
fieldnames = collected_test_mappings[next(iter(collected_test_mappings))].keys()
Expand All @@ -327,4 +423,4 @@ def pytest_sessionfinish(self, session: pytest.Session, exitstatus):
def pytest_runtest_logreport(self, report: pytest.TestReport):
# Capture test outcomes and save to collection
if report.when == "call":
collected_tests.update_test_status(report.nodeid, "PASS" if report.passed else "FAIL")
self.collected_tests.update_test_status(report.nodeid, "PASS" if report.passed else "FAIL")
22 changes: 21 additions & 1 deletion django_utils_lib/testing/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,37 @@
from __future__ import annotations

import re
from typing import List, Tuple, cast
from typing import List, Tuple, Union, cast

import pytest
from typing_extensions import TypedDict
from xdist import is_xdist_worker

PytestNodeID = str
"""
A pytest node ID follows the format of `file_path::test_name`
"""


def is_main_pytest_runner(pytest_obj: Union[pytest.Config, pytest.FixtureRequest, pytest.Session]):
    """
    Report whether the current process is the main pytest runner, as opposed to
    an xdist worker. Intended to work both with and without xdist.

    Args:
        pytest_obj: A pytest config (or any object carrying `workerinput`), or
            a session / fixture request that exposes `.config`

    Returns:
        `True` on the main runner; `False` on a worker, or when `pytest_obj` is
        none of the recognised kinds.
    """
    is_config_like = isinstance(pytest_obj, pytest.Config) or hasattr(pytest_obj, "workerinput")
    if is_config_like:
        # A populated `workerinput`, on either a config or a distributed node,
        # marks a worker
        worker_input = getattr(pytest_obj, "workerinput", None)
        return worker_input is None

    if not hasattr(pytest_obj, "config"):
        # Not something we know how to inspect
        return False

    # Session objects or requests: defer to xdist's own check
    return is_xdist_worker(pytest_obj) is False


class RequirementValidationResults(TypedDict):
valid: bool
errors: List[str]
Expand Down

0 comments on commit fb29cbf

Please sign in to comment.