diff --git a/src/xdist/remote.py b/src/xdist/remote.py index 21e3893b..32dae34b 100644 --- a/src/xdist/remote.py +++ b/src/xdist/remote.py @@ -62,10 +62,13 @@ def __init__(self, config, channel): self.testrunuid = config.workerinput["testrunuid"] self.log = Producer(f"worker-{self.workerid}", enabled=config.option.debug) self.channel = channel - self.torun = channel.gateway.execmodel.queue.Queue() + self.torun = self._make_queue() self.nextitem_index = None config.pluginmanager.register(self) + def _make_queue(self): + return channel.gateway.execmodel.queue.Queue() + def sendevent(self, name, **kwargs): self.log("sending", name, kwargs) self.channel.send((name, kwargs)) @@ -110,6 +113,20 @@ def handle_command(self, command): self.torun.put(i) elif name == "shutdown": self.torun.put(None) + elif name == "steal": + old_queue, self.torun = self.torun, self._make_queue() + stolen = [] + steal_set = set(kwargs["indices"]) + try: + while 1: + i = old_queue.get_nowait() + if i in steal_set: + stolen.append(i) + else: + self.torun.put(i) + except self.channel.gateway.execmodel.queue.Empty: + pass + self.sendevent("unscheduled", indices=stolen) @pytest.hookimpl def pytest_runtestloop(self, session): diff --git a/testing/test_remote.py b/testing/test_remote.py index 8742ee5c..cb8f6b7f 100644 --- a/testing/test_remote.py +++ b/testing/test_remote.py @@ -220,6 +220,57 @@ def test_process_from_remote_error_handling( ev = worker.popevent() assert ev.name == "errordown" + def test_steal_work(self, worker: WorkerSetup, unserialize_report) -> None: + worker.pytester.makepyfile( + """ + import time + def test_func(): time.sleep(1) + def test_func2(): pass + def test_func3(): pass + def test_func4(): pass + """ + ) + worker.setup() + ev = worker.popevent("collectionfinish") + ids = ev.kwargs["ids"] + assert len(ids) == 4 + worker.sendcommand("runtests_all") + + # wait for test_func setup + ev = worker.popevent("testreport") + rep = unserialize_report(ev.kwargs["data"]) + assert rep.nodeid.endswith("::test_func") + assert rep.when == "setup" + + worker.sendcommand("steal", indices=[1, 2]) + ev = worker.popevent("unscheduled") + assert ev.kwargs["indices"] == [2] + + reports = [ + ("test_func", "call"), + ("test_func", "teardown"), + ("test_func2", "setup"), + ("test_func2", "call"), + ("test_func2", "teardown"), + ] + + for func, when in reports: + ev = worker.popevent("testreport") + rep = unserialize_report(ev.kwargs["data"]) + assert rep.nodeid.endswith(f"::{func}") + assert rep.when == when + + worker.sendcommand("shutdown") + + for when in ["setup", "call", "teardown"]: + ev = worker.popevent("testreport") + rep = unserialize_report(ev.kwargs["data"]) + assert rep.nodeid.endswith("::test_func4") + assert rep.when == when + + ev = worker.popevent("workerfinished") + assert "workeroutput" in ev.kwargs + def test_remote_env_vars(pytester: pytest.Pytester) -> None: pytester.makepyfile(