From 26d94ceaf4898b63f1b160334292a5c9c0182892 Mon Sep 17 00:00:00 2001 From: Aleksandr Mezin Date: Sat, 31 Dec 2022 15:59:12 +0300 Subject: [PATCH] WorkerInteractor: implement "steal" command --- src/xdist/remote.py | 26 ++++++++++++++++++++- testing/test_remote.py | 51 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/src/xdist/remote.py b/src/xdist/remote.py index 6f91fed3..2e83a8dc 100644 --- a/src/xdist/remote.py +++ b/src/xdist/remote.py @@ -6,6 +6,7 @@ needs not to be installed in remote environments. """ +import contextlib import sys import os import time @@ -64,10 +65,13 @@ def __init__(self, config, channel): self.testrunuid = config.workerinput["testrunuid"] self.log = Producer(f"worker-{self.workerid}", enabled=config.option.debug) self.channel = channel - self.torun = channel.gateway.execmodel.queue.Queue() + self.torun = self._make_queue() self.nextitem_index = None config.pluginmanager.register(self) + def _make_queue(self): + return self.channel.gateway.execmodel.queue.Queue() + def sendevent(self, name, **kwargs): self.log("sending", name, kwargs) self.channel.send((name, kwargs)) @@ -112,6 +116,26 @@ def handle_command(self, command): self.torun.put(i) elif name == "shutdown": self.torun.put(self.SHUTDOWN_MARK) + elif name == "steal": + self.steal(kwargs["indices"]) + + def steal(self, indices): + indices = set(indices) + stolen = [] + + old_queue, self.torun = self.torun, self._make_queue() + + def old_queue_get_nowait_noraise(): + with contextlib.suppress(self.channel.gateway.execmodel.queue.Empty): + return old_queue.get_nowait() + + for i in iter(old_queue_get_nowait_noraise, None): + if i in indices: + stolen.append(i) + else: + self.torun.put(i) + + self.sendevent("unscheduled", indices=stolen) @pytest.hookimpl def pytest_runtestloop(self, session): diff --git a/testing/test_remote.py b/testing/test_remote.py index 8742ee5c..cb8f6b7f 100644 --- a/testing/test_remote.py +++ b/testing/test_remote.py @@ -220,6 +220,57 @@ def test_process_from_remote_error_handling( ev = worker.popevent() assert ev.name == "errordown" + def test_steal_work(self, worker: WorkerSetup, unserialize_report) -> None: + worker.pytester.makepyfile( + """ + import time + def test_func(): time.sleep(1) + def test_func2(): pass + def test_func3(): pass + def test_func4(): pass + """ + ) + worker.setup() + ev = worker.popevent("collectionfinish") + ids = ev.kwargs["ids"] + assert len(ids) == 4 + worker.sendcommand("runtests_all") + + # wait for test_func setup + ev = worker.popevent("testreport") + rep = unserialize_report(ev.kwargs["data"]) + assert rep.nodeid.endswith("::test_func") + assert rep.when == "setup" + + worker.sendcommand("steal", indices=[1, 2]) + ev = worker.popevent("unscheduled") + assert ev.kwargs["indices"] == [2] + + reports = [ + ("test_func", "call"), + ("test_func", "teardown"), + ("test_func2", "setup"), + ("test_func2", "call"), + ("test_func2", "teardown"), + ] + + for func, when in reports: + ev = worker.popevent("testreport") + rep = unserialize_report(ev.kwargs["data"]) + assert rep.nodeid.endswith(f"::{func}") + assert rep.when == when + + worker.sendcommand("shutdown") + + for when in ["setup", "call", "teardown"]: + ev = worker.popevent("testreport") + rep = unserialize_report(ev.kwargs["data"]) + assert rep.nodeid.endswith("::test_func4") + assert rep.when == when + + ev = worker.popevent("workerfinished") + assert "workeroutput" in ev.kwargs + def test_remote_env_vars(pytester: pytest.Pytester) -> None: pytester.makepyfile(