From 9788f12e9e558e813c667df42ce2f33c09c35758 Mon Sep 17 00:00:00 2001 From: Aleksandr Mezin Date: Wed, 30 Oct 2024 19:00:10 +0200 Subject: [PATCH] Make 'steal' command atomic (#1144) Either unschedule all requested tests, or none if that is not possible - that is, if some of the requested tests have already been processed by the time the request arrives. This may happen if the worker runs tests faster than the controller receives and processes status updates. But in this case maybe it's just better to let the worker keep running. This is a prerequisite for group/scope support in the worksteal scheduler, so that test groups won't be broken up incorrectly. This change could break schedulers that use the "steal" command. However: 1) the worksteal scheduler doesn't need any adjustments. 2) I'm not aware of any external schedulers relying on this command yet. So I think it's better to keep the protocol simple, not complicate it for imaginary compatibility with some unknown and likely non-existent schedulers. Co-authored-by: Bruno Oliveira --- changelog/1144.feature | 3 ++ src/xdist/remote.py | 88 ++++++++++++++++++++++++++++-------------- testing/test_remote.py | 6 +++ 3 files changed, 69 insertions(+), 28 deletions(-) create mode 100644 changelog/1144.feature diff --git a/changelog/1144.feature b/changelog/1144.feature new file mode 100644 index 00000000..911d3d00 --- /dev/null +++ b/changelog/1144.feature @@ -0,0 +1,3 @@ +The internal `steal` command is now atomic - it unschedules either all requested tests or none. + +This is a prerequisite for group/scope support in the `worksteal` scheduler, so test groups won't be broken up incorrectly. 
diff --git a/src/xdist/remote.py b/src/xdist/remote.py index dd1f9883..5439f6f0 100644 --- a/src/xdist/remote.py +++ b/src/xdist/remote.py @@ -8,6 +8,7 @@ from __future__ import annotations +import collections import contextlib import enum import os @@ -15,9 +16,11 @@ import time from typing import Any from typing import Generator +from typing import Iterable from typing import Literal from typing import Sequence from typing import TypedDict +from typing import Union import warnings from _pytest.config import _prepareconfig @@ -66,7 +69,44 @@ def worker_title(title: str) -> None: class Marker(enum.Enum): SHUTDOWN = 0 - QUEUE_REPLACED = 1 + + +class TestQueue: + """A simple queue that can be inspected and modified while the lock is held via the ``lock()`` method.""" + + Item = Union[int, Literal[Marker.SHUTDOWN]] + + def __init__(self, execmodel: execnet.gateway_base.ExecModel): + self._items: collections.deque[TestQueue.Item] = collections.deque() + self._lock = execmodel.RLock() # type: ignore[no-untyped-call] + self._has_items_event = execmodel.Event() + + def get(self) -> Item: + while True: + with self.lock() as locked_items: + if locked_items: + return locked_items.popleft() + + self._has_items_event.wait() + + def put(self, item: Item) -> None: + with self.lock() as locked_items: + locked_items.append(item) + + def replace(self, iterable: Iterable[Item]) -> None: + with self.lock(): + self._items = collections.deque(iterable) + + @contextlib.contextmanager + def lock(self) -> Generator[collections.deque[Item], None, None]: + with self._lock: + try: + yield self._items + finally: + if self._items: + self._has_items_event.set() + else: + self._has_items_event.clear() class WorkerInteractor: @@ -77,22 +117,10 @@ def __init__(self, config: pytest.Config, channel: execnet.Channel) -> None: self.testrunuid = workerinput["testrunuid"] self.log = Producer(f"worker-{self.workerid}", enabled=config.option.debug) self.channel = channel - self.torun = self._make_queue() 
+ self.torun = TestQueue(self.channel.gateway.execmodel) self.nextitem_index: int | None | Literal[Marker.SHUTDOWN] = None config.pluginmanager.register(self) - def _make_queue(self) -> Any: - return self.channel.gateway.execmodel.queue.Queue() - - def _get_next_item_index(self) -> int | Literal[Marker.SHUTDOWN]: - """Gets the next item from test queue. Handles the case when the queue - is replaced concurrently in another thread. - """ - result = self.torun.get() - while result is Marker.QUEUE_REPLACED: - result = self.torun.get() - return result # type: ignore[no-any-return] - def sendevent(self, name: str, **kwargs: object) -> None: self.log("sending", name, kwargs) self.channel.send((name, kwargs)) @@ -146,30 +174,34 @@ def handle_command( self.steal(kwargs["indices"]) def steal(self, indices: Sequence[int]) -> None: - indices_set = set(indices) - stolen = [] + """ + Remove tests from the queue. - old_queue, self.torun = self.torun, self._make_queue() + Removes either all requested tests, or none, if some of these tests + are not in the queue (for example, if they were processed already). - def old_queue_get_nowait_noraise() -> int | None: - with contextlib.suppress(self.channel.gateway.execmodel.queue.Empty): - return old_queue.get_nowait() # type: ignore[no-any-return] - return None + :param indices: indices of the tests to remove. 
+ """ + requested_set = set(indices) + + with self.torun.lock() as locked_queue: + stolen = list(item for item in locked_queue if item in requested_set) - for i in iter(old_queue_get_nowait_noraise, None): - if i in indices_set: - stolen.append(i) + # Stealing only if all requested tests are still pending + if len(stolen) == len(requested_set): + self.torun.replace( + item for item in locked_queue if item not in requested_set + ) else: - self.torun.put(i) + stolen = [] self.sendevent("unscheduled", indices=stolen) - old_queue.put(Marker.QUEUE_REPLACED) @pytest.hookimpl def pytest_runtestloop(self, session: pytest.Session) -> bool: self.log("entering main loop") self.channel.setcallback(self.handle_command, endmarker=Marker.SHUTDOWN) - self.nextitem_index = self._get_next_item_index() + self.nextitem_index = self.torun.get() while self.nextitem_index is not Marker.SHUTDOWN: self.run_one_test() if session.shouldfail or session.shouldstop: @@ -179,7 +211,7 @@ def pytest_runtestloop(self, session: pytest.Session) -> bool: def run_one_test(self) -> None: assert isinstance(self.nextitem_index, int) self.item_index = self.nextitem_index - self.nextitem_index = self._get_next_item_index() + self.nextitem_index = self.torun.get() items = self.session.items item = items[self.item_index] diff --git a/testing/test_remote.py b/testing/test_remote.py index 0b0334dc..b995cc4a 100644 --- a/testing/test_remote.py +++ b/testing/test_remote.py @@ -267,6 +267,12 @@ def test_func4(): pass worker.sendcommand("steal", indices=[1, 2]) ev = worker.popevent("unscheduled") + # Cannot steal index 1 because it is completed already, so do not steal any. + assert ev.kwargs["indices"] == [] + + # Index 2 can be stolen, as it is still pending. + worker.sendcommand("steal", indices=[2]) + ev = worker.popevent("unscheduled") assert ev.kwargs["indices"] == [2] reports = [