diff --git a/changelog/858.feature b/changelog/858.feature new file mode 100644 index 00000000..72bb8ff0 --- /dev/null +++ b/changelog/858.feature @@ -0,0 +1 @@ +New ``worksteal`` scheduler, based on the idea of `work stealing `_. It's similar to ``load`` scheduler, but it should handle tests with significantly differing duration better, and, at the same time, it should provide similar or better reuse of fixtures. diff --git a/src/xdist/dsession.py b/src/xdist/dsession.py index 4cea59bc..a622b8bd 100644 --- a/src/xdist/dsession.py +++ b/src/xdist/dsession.py @@ -8,6 +8,7 @@ LoadScopeScheduling, LoadFileScheduling, LoadGroupScheduling, + WorkStealingScheduling, ) @@ -100,6 +101,7 @@ def pytest_xdist_make_scheduler(self, config, log): "loadscope": LoadScopeScheduling, "loadfile": LoadFileScheduling, "loadgroup": LoadGroupScheduling, + "worksteal": WorkStealingScheduling, } return schedulers[dist](config, log) @@ -282,6 +284,17 @@ def worker_runtest_protocol_complete(self, node, item_index, duration): """ self.sched.mark_test_complete(node, item_index, duration) + def worker_unscheduled(self, node, indices): + """ + Emitted when a node fires the 'unscheduled' event, signalling that + some tests have been removed from the worker's queue and should be + sent to some worker again. + + This should happen only in response to 'steal' command, so schedulers + not using 'steal' command don't have to implement it. + """ + self.sched.remove_pending_tests_from_node(node, indices) + def worker_collectreport(self, node, rep): """Emitted when a node calls the pytest_collectreport hook. diff --git a/src/xdist/plugin.py b/src/xdist/plugin.py index 9bbbae74..b08b8421 100644 --- a/src/xdist/plugin.py +++ b/src/xdist/plugin.py @@ -94,7 +94,15 @@ def pytest_addoption(parser): "--dist", metavar="distmode", action="store", - choices=["each", "load", "loadscope", "loadfile", "loadgroup", "no"], + choices=[ + "each", + "load", + "loadscope", + "loadfile", + "loadgroup", + "worksteal", + "no", + ], dest="dist", default="no", help=( @@ -107,6 +115,8 @@ def pytest_addoption(parser): "loadfile: load balance by sending test grouped by file" " to any available environment.\n\n" "loadgroup: like load, but sends tests marked with 'xdist_group' to the same worker.\n\n" + "worksteal: split the test suite between available environments," + " then rebalance when any worker runs out of tests.\n\n" "(default) no: run tests inprocess, don't distribute." ), ) diff --git a/src/xdist/scheduler/__init__.py b/src/xdist/scheduler/__init__.py index ab2e830f..9201cda8 100644 --- a/src/xdist/scheduler/__init__.py +++ b/src/xdist/scheduler/__init__.py @@ -3,3 +3,4 @@ from xdist.scheduler.loadfile import LoadFileScheduling # noqa from xdist.scheduler.loadscope import LoadScopeScheduling # noqa from xdist.scheduler.loadgroup import LoadGroupScheduling # noqa +from xdist.scheduler.worksteal import WorkStealingScheduling # noqa diff --git a/src/xdist/scheduler/worksteal.py b/src/xdist/scheduler/worksteal.py new file mode 100644 index 00000000..534402c0 --- /dev/null +++ b/src/xdist/scheduler/worksteal.py @@ -0,0 +1,319 @@ +from collections import namedtuple + +from _pytest.runner import CollectReport + +from xdist.remote import Producer +from xdist.workermanage import parse_spec_config +from xdist.report import report_collection_diff + + +NodePending = namedtuple("NodePending", ["node", "pending"]) + +# Every worker needs at least 2 tests in queue - the current and the next one. +MIN_PENDING = 2 + + +class WorkStealingScheduling: + """Implement work-stealing scheduling. + + Initially, tests are distributed evenly among all nodes. + + When some node completes most of its assigned tests (when only one pending + test remains), an attempt to reassign some tests to that node is made. + + Attributes: + + :numnodes: The expected number of nodes taking part. The actual + number of nodes will vary during the scheduler's lifetime as + nodes are added by the DSession as they are brought up and + removed either because of a dead node or normal shutdown. This + number is primarily used to know when the initial collection is + completed. + + :node2collection: Map of nodes and their test collection. All + collections should always be identical. + + :node2pending: Map of nodes and the indices of their pending + tests. The indices are an index into ``.pending`` (which is + identical to their own collection stored in + ``.node2collection``). + + :collection: The one collection once it is validated to be + identical between all the nodes. It is initialised to None + until ``.schedule()`` is called. + + :pending: List of indices of globally pending tests. These are + tests which have not yet been allocated to a chunk for a node + to process. + + :log: A py.log.Producer instance. + + :config: Config object, used for handling hooks. + """ + + def __init__(self, config, log=None): + self.numnodes = len(parse_spec_config(config)) + self.node2collection = {} + self.node2pending = {} + self.pending = [] + self.collection = None + if log is None: + self.log = Producer("workstealsched") + else: + self.log = log.workstealsched + self.config = config + self.steal_requested = None + + @property + def nodes(self): + """A list of all nodes in the scheduler.""" + return list(self.node2pending.keys()) + + @property + def collection_is_completed(self): + """Boolean indication initial test collection is complete. + + This is a boolean indicating all initial participating nodes + have finished collection. The required number of initial + nodes is defined by ``.numnodes``. + """ + return len(self.node2collection) >= self.numnodes + + @property + def tests_finished(self): + """Return True if all tests have been executed by the nodes.""" + if not self.collection_is_completed: + return False + if self.pending: + return False + if self.steal_requested is not None: + return False + for pending in self.node2pending.values(): + if len(pending) >= MIN_PENDING: + return False + return True + + @property + def has_pending(self): + """Return True if there are pending test items + + This indicates that collection has finished and nodes are + still processing test items, so this can be thought of as + "the scheduler is active". + """ + if self.pending: + return True + for pending in self.node2pending.values(): + if pending: + return True + return False + + def add_node(self, node): + """Add a new node to the scheduler. + + From now on the node will be allocated chunks of tests to + execute. + + Called by the ``DSession.worker_workerready`` hook when it + successfully bootstraps a new node. + """ + assert node not in self.node2pending + self.node2pending[node] = [] + + def add_node_collection(self, node, collection): + """Add the collected test items from a node + + The collection is stored in the ``.node2collection`` map. + Called by the ``DSession.worker_collectionfinish`` hook. + """ + assert node in self.node2pending + if self.collection_is_completed: + # A new node has been added later, perhaps an original one died. + # .schedule() should have + # been called by now + assert self.collection + if collection != self.collection: + other_node = next(iter(self.node2collection.keys())) + msg = report_collection_diff( + self.collection, collection, other_node.gateway.id, node.gateway.id + ) + self.log(msg) + return + self.node2collection[node] = list(collection) + + def mark_test_complete(self, node, item_index, duration=None): + """Mark test item as completed by node + + This is called by the ``DSession.worker_testreport`` hook. + """ + self.node2pending[node].remove(item_index) + self.check_schedule() + + def mark_test_pending(self, item): + self.pending.insert( + 0, + self.collection.index(item), + ) + self.check_schedule() + + def remove_pending_tests_from_node(self, node, indices): + """Node returned some test indices back in response to 'steal' command. + + This is called by ``DSession.worker_unscheduled``. + """ + assert node is self.steal_requested + self.steal_requested = None + + indices_set = set(indices) + self.node2pending[node] = [ + i for i in self.node2pending[node] if i not in indices_set + ] + self.pending.extend(indices) + self.check_schedule() + + def check_schedule(self): + """Reschedule tests/perform load balancing.""" + nodes_up = [ + NodePending(node, pending) + for node, pending in self.node2pending.items() + if not node.shutting_down + ] + + def get_idle_nodes(): + return [node for node, pending in nodes_up if len(pending) < MIN_PENDING] + + idle_nodes = get_idle_nodes() + if not idle_nodes: + return + + if self.pending: + # Distribute pending tests evenly among idle nodes + for i, node in enumerate(idle_nodes): + nodes_remaining = len(idle_nodes) - i + num_send = len(self.pending) // nodes_remaining + self._send_tests(node, num_send) + + idle_nodes = get_idle_nodes() + # No need to steal anything if all nodes have enough work to continue + if not idle_nodes: + return + + # Only one active stealing request is allowed + if self.steal_requested is not None: + return + + # Find the node that has the longest test queue + steal_from = max( + nodes_up, key=lambda node_pending: len(node_pending.pending), default=None + ) + + if steal_from is None: + num_steal = 0 + else: + # Steal half of the test queue - but keep that node running too. + # If the node has 2 or less tests queued, stealing will fail + # anyway. + max_steal = max(0, len(steal_from.pending) - MIN_PENDING) + num_steal = min(len(steal_from.pending) // 2, max_steal) + + if num_steal == 0: + # Can't get more work - shutdown idle nodes. This will force them + # to run the last test now instead of waiting for more tests. + for node in idle_nodes: + node.shutdown() + return + + steal_from.node.send_steal(steal_from.pending[-num_steal:]) + self.steal_requested = steal_from.node + + def remove_node(self, node): + """Remove a node from the scheduler + + This should be called either when the node crashed or at + shutdown time. In the former case any pending items assigned + to the node will be re-scheduled. Called by the + ``DSession.worker_workerfinished`` and + ``DSession.worker_errordown`` hooks. + + Return the item which was being executing while the node + crashed or None if the node has no more pending items. + + """ + pending = self.node2pending.pop(node) + + # If node was removed without completing its assigned tests - it crashed + if pending: + crashitem = self.collection[pending.pop(0)] + else: + crashitem = None + + self.pending.extend(pending) + + # Dead node won't respond to "steal" request + if self.steal_requested is node: + self.steal_requested = None + + self.check_schedule() + return crashitem + + def schedule(self): + """Initiate distribution of the test collection + + Initiate scheduling of the items across the nodes. If this + gets called again later it behaves the same as calling + ``.check_schedule()`` on all nodes so that newly added nodes + will start to be used. + + This is called by the ``DSession.worker_collectionfinish`` hook + if ``.collection_is_completed`` is True. + """ + assert self.collection_is_completed + + # Initial distribution already happened, reschedule on all nodes + if self.collection is not None: + self.check_schedule() + return + + if not self._check_nodes_have_same_collection(): + self.log("**Different tests collected, aborting run**") + return + + # Collections are identical, create the index of pending items. + self.collection = list(self.node2collection.values())[0] + self.pending[:] = range(len(self.collection)) + if not self.collection: + return + + self.check_schedule() + + def _send_tests(self, node, num): + tests_per_node = self.pending[:num] + if tests_per_node: + del self.pending[:num] + self.node2pending[node].extend(tests_per_node) + node.send_runtest_some(tests_per_node) + + def _check_nodes_have_same_collection(self): + """Return True if all nodes have collected the same items. + + If collections differ, this method returns False while logging + the collection differences and posting collection errors to + pytest_collectreport hook. + """ + node_collection_items = list(self.node2collection.items()) + first_node, col = node_collection_items[0] + same_collection = True + for node, collection in node_collection_items[1:]: + msg = report_collection_diff( + col, collection, first_node.gateway.id, node.gateway.id + ) + if msg: + same_collection = False + self.log(msg) + if self.config is not None: + rep = CollectReport( + node.gateway.id, "failed", longrepr=msg, result=[] + ) + self.config.hook.pytest_collectreport(report=rep) + + return same_collection diff --git a/src/xdist/workermanage.py b/src/xdist/workermanage.py index 10f681cc..fdd4109a 100644 --- a/src/xdist/workermanage.py +++ b/src/xdist/workermanage.py @@ -300,6 +300,9 @@ def send_runtest_some(self, indices): def send_runtest_all(self): self.sendcommand("runtests_all") + def send_steal(self, indices): + self.sendcommand("steal", indices=indices) + def shutdown(self): if not self._down: try: @@ -359,6 +362,8 @@ def process_from_remote(self, eventcall): # noqa too complex self.notify_inproc(eventname, node=self, ids=kwargs["ids"]) elif eventname == "runtest_protocol_complete": self.notify_inproc(eventname, node=self, **kwargs) + elif eventname == "unscheduled": + self.notify_inproc(eventname, node=self, **kwargs) elif eventname == "logwarning": self.notify_inproc( eventname, diff --git a/testing/test_dsession.py b/testing/test_dsession.py index 24ec4ae9..e2490f1d 100644 --- a/testing/test_dsession.py +++ b/testing/test_dsession.py @@ -1,6 +1,6 @@ from xdist.dsession import DSession, get_default_max_worker_restart from xdist.report import report_collection_diff -from xdist.scheduler import EachScheduling, LoadScheduling +from xdist.scheduler import EachScheduling, LoadScheduling, WorkStealingScheduling from typing import Optional import pytest @@ -17,6 +17,7 @@ def __init__(self) -> None: class MockNode: def __init__(self) -> None: self.sent = [] # type: ignore[var-annotated] + self.stolen = [] self.gateway = MockGateway() self._shutdown = False @@ -26,6 +27,9 @@ def send_runtest_some(self, indices) -> None: def send_runtest_all(self) -> None: self.sent.append("ALL") + def send_steal(self, indices) -> None: + self.stolen.extend(indices) + def shutdown(self) -> None: self._shutdown = True @@ -267,6 +271,169 @@ def pytest_collectreport(self, report): assert "Different tests were collected between" in rep.longrepr +class TestWorkStealingScheduling: + def test_ideal_case(self, pytester: pytest.Pytester) -> None: + config = pytester.parseconfig("--tx=2*popen") + sched = WorkStealingScheduling(config) + sched.add_node(MockNode()) + sched.add_node(MockNode()) + node1, node2 = sched.nodes + collection = [f"test_workstealing.py::test_{i}" for i in range(16)] + assert not sched.collection_is_completed + sched.add_node_collection(node1, collection) + assert not sched.collection_is_completed + sched.add_node_collection(node2, collection) + assert sched.collection_is_completed + assert sched.node2collection[node1] == collection + assert sched.node2collection[node2] == collection + sched.schedule() + assert not sched.pending + assert not sched.tests_finished + assert node1.sent == list(range(0, 8)) + assert node2.sent == list(range(8, 16)) + for i in range(8): + sched.mark_test_complete(node1, node1.sent[i]) + sched.mark_test_complete(node2, node2.sent[i]) + assert sched.tests_finished + assert node1.stolen == [] + assert node2.stolen == [] + + def test_stealing(self, pytester: pytest.Pytester) -> None: + config = pytester.parseconfig("--tx=2*popen") + sched = WorkStealingScheduling(config) + sched.add_node(MockNode()) + sched.add_node(MockNode()) + node1, node2 = sched.nodes + collection = [f"test_workstealing.py::test_{i}" for i in range(16)] + sched.add_node_collection(node1, collection) + sched.add_node_collection(node2, collection) + assert sched.collection_is_completed + sched.schedule() + assert node1.sent == list(range(0, 8)) + assert node2.sent == list(range(8, 16)) + for i in range(8): + sched.mark_test_complete(node1, node1.sent[i]) + assert node2.stolen == list(range(12, 16)) + sched.remove_pending_tests_from_node(node2, node2.stolen) + for i in range(4): + sched.mark_test_complete(node2, node2.sent[i]) + assert node1.stolen == [14, 15] + sched.remove_pending_tests_from_node(node1, node1.stolen) + sched.mark_test_complete(node1, 12) + sched.mark_test_complete(node2, 14) + assert node2.stolen == list(range(12, 16)) + assert node1.stolen == [14, 15] + assert sched.tests_finished + + def test_steal_on_add_node(self, pytester: pytest.Pytester) -> None: + node = MockNode() + config = pytester.parseconfig("--tx=popen") + sched = WorkStealingScheduling(config) + sched.add_node(node) + collection = [f"test_workstealing.py::test_{i}" for i in range(5)] + sched.add_node_collection(node, collection) + assert sched.collection_is_completed + sched.schedule() + assert not sched.pending + sched.mark_test_complete(node, 0) + node2 = MockNode() + sched.add_node(node2) + sched.add_node_collection(node2, collection) + assert sched.collection_is_completed + sched.schedule() + assert node.stolen == [3, 4] + sched.remove_pending_tests_from_node(node, node.stolen) + sched.mark_test_complete(node, 1) + sched.mark_test_complete(node2, 3) + assert sched.tests_finished + assert node2.stolen == [] + + def test_schedule_fewer_tests_than_nodes(self, pytester: pytest.Pytester) -> None: + config = pytester.parseconfig("--tx=3*popen") + sched = WorkStealingScheduling(config) + sched.add_node(MockNode()) + sched.add_node(MockNode()) + sched.add_node(MockNode()) + node1, node2, node3 = sched.nodes + col = ["xyz"] * 2 + sched.add_node_collection(node1, col) + sched.add_node_collection(node2, col) + sched.add_node_collection(node3, col) + sched.schedule() + assert node1.sent == [] + assert node1.stolen == [] + assert node2.sent == [0] + assert node2.stolen == [] + assert node3.sent == [1] + assert node3.stolen == [] + assert not sched.pending + assert sched.tests_finished + + def test_schedule_fewer_than_two_tests_per_node( + self, pytester: pytest.Pytester + ) -> None: + config = pytester.parseconfig("--tx=3*popen") + sched = WorkStealingScheduling(config) + sched.add_node(MockNode()) + sched.add_node(MockNode()) + sched.add_node(MockNode()) + node1, node2, node3 = sched.nodes + col = ["xyz"] * 5 + sched.add_node_collection(node1, col) + sched.add_node_collection(node2, col) + sched.add_node_collection(node3, col) + sched.schedule() + assert node1.sent == [0] + assert node2.sent == [1, 2] + assert node3.sent == [3, 4] + assert not sched.pending + assert not sched.tests_finished + sched.mark_test_complete(node1, node1.sent[0]) + sched.mark_test_complete(node2, node2.sent[0]) + sched.mark_test_complete(node3, node3.sent[0]) + sched.mark_test_complete(node3, node3.sent[1]) + assert sched.tests_finished + assert node1.stolen == [] + assert node2.stolen == [] + assert node3.stolen == [] + + def test_add_remove_node(self, pytester: pytest.Pytester) -> None: + node = MockNode() + config = pytester.parseconfig("--tx=popen") + sched = WorkStealingScheduling(config) + sched.add_node(node) + collection = ["test_file.py::test_func"] + sched.add_node_collection(node, collection) + assert sched.collection_is_completed + sched.schedule() + assert not sched.pending + crashitem = sched.remove_node(node) + assert crashitem == collection[0] + + def test_different_tests_collected(self, pytester: pytest.Pytester) -> None: + class CollectHook: + def __init__(self): + self.reports = [] + + def pytest_collectreport(self, report): + self.reports.append(report) + + collect_hook = CollectHook() + config = pytester.parseconfig("--tx=2*popen") + config.pluginmanager.register(collect_hook, "collect_hook") + node1 = MockNode() + node2 = MockNode() + sched = WorkStealingScheduling(config) + sched.add_node(node1) + sched.add_node(node2) + sched.add_node_collection(node1, ["a.py::test_1"]) + sched.add_node_collection(node2, ["a.py::test_2"]) + sched.schedule() + assert len(collect_hook.reports) == 1 + rep = collect_hook.reports[0] + assert "Different tests were collected between" in rep.longrepr + + class TestDistReporter: @pytest.mark.xfail def test_rsync_printing(self, pytester: pytest.Pytester, linecomp) -> None: