From ff250f2b24025a49da294a6a8a140d07250c740b Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 30 Jun 2022 14:57:21 +0100 Subject: [PATCH] Scheduler.reschedule() works only by accident (#6339) --- distributed/scheduler.py | 28 +++--- distributed/stealing.py | 2 +- distributed/tests/test_client.py | 14 +-- distributed/tests/test_reschedule.py | 123 +++++++++++++++++++++++++++ distributed/tests/test_scheduler.py | 41 ++------- distributed/tests/test_steal.py | 4 +- distributed/tests/test_worker.py | 18 ---- distributed/worker.py | 2 +- distributed/worker_state_machine.py | 14 ++- 9 files changed, 163 insertions(+), 83 deletions(-) create mode 100644 distributed/tests/test_reschedule.py diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 3e2fd1ff5c..adb1bbc45f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2223,21 +2223,19 @@ def transition_waiting_released(self, key, stimulus_id): pdb.set_trace() raise - def transition_processing_released(self, key, stimulus_id): + def transition_processing_released(self, key: str, stimulus_id: str): try: - ts: TaskState = self.tasks[key] - dts: TaskState - recommendations: dict = {} - client_msgs: dict = {} - worker_msgs: dict = {} + ts = self.tasks[key] + recommendations = {} + worker_msgs = {} if self.validate: assert ts.processing_on assert not ts.who_has assert not ts.waiting_on - assert self.tasks[key].state == "processing" + assert ts.state == "processing" - w: str = _remove_from_processing(self, ts) + w = _remove_from_processing(self, ts) if w: worker_msgs[w] = [ { @@ -2265,7 +2263,7 @@ def transition_processing_released(self, key, stimulus_id): if self.validate: assert not ts.processing_on - return recommendations, client_msgs, worker_msgs + return recommendations, {}, worker_msgs except Exception as e: logger.exception(e) if LOG_PDB: @@ -6606,7 +6604,9 @@ async def get_story(self, keys_or_stimuli: Iterable[str]) -> list[tuple]: transition_story = story - def reschedule(self, key=None, worker=None): + def reschedule( + self, key: str, worker: str | None = None, *, stimulus_id: str + ) -> None: """Reschedule a task Things may have shifted and this task may now be better suited to run @@ -6616,15 +6616,17 @@ def reschedule(self, key=None, worker=None): ts = self.tasks[key] except KeyError: logger.warning( - "Attempting to reschedule task {}, which was not " - "found on the scheduler. Aborting reschedule.".format(key) + f"Attempting to reschedule task {key}, which was not " + "found on the scheduler. Aborting reschedule." ) return if ts.state != "processing": return if worker and ts.processing_on.address != worker: return - self.transitions({key: "released"}, f"reschedule-{time()}") + # transition_processing_released will immediately suggest an additional + # transition to waiting if the task has any waiters or clients holding a future. + self.transitions({key: "released"}, stimulus_id=stimulus_id) ##################### # Utility functions # diff --git a/distributed/stealing.py b/distributed/stealing.py index e3d5d25678..b9858f4656 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -341,7 +341,7 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None): *_log_msg, ) ) - self.scheduler.reschedule(key) + self.scheduler.reschedule(key, stimulus_id=stimulus_id) # Victim had already started execution elif state in _WORKER_STATE_REJECT: self.log(("already-computing", *_log_msg)) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 159e5116a7..42f019bb98 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -4299,11 +4299,12 @@ async def test_retire_many_workers(c, s, *workers): @gen_cluster( client=True, nthreads=[("127.0.0.1", 3)] * 2, - config={"distributed.scheduler.default-task-durations": {"f": "10ms"}}, + config={ + "distributed.scheduler.work-stealing": False, + "distributed.scheduler.default-task-durations": {"f": "10ms"}, + }, ) async def test_weight_occupancy_against_data_movement(c, s, a, b): - await s.extensions["stealing"].stop() - def f(x, y=0, z=0): sleep(0.01) return x @@ -4322,11 +4323,12 @@ def f(x, y=0, z=0): @gen_cluster( client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.1", 10)], - config={"distributed.scheduler.default-task-durations": {"f": "10ms"}}, + config={ + "distributed.scheduler.work-stealing": False, + "distributed.scheduler.default-task-durations": {"f": "10ms"}, + }, ) async def test_distribute_tasks_by_nthreads(c, s, a, b): - await s.extensions["stealing"].stop() - def f(x, y=0): sleep(0.01) return x diff --git a/distributed/tests/test_reschedule.py b/distributed/tests/test_reschedule.py new file mode 100644 index 0000000000..c2dc71cbeb --- /dev/null +++ b/distributed/tests/test_reschedule.py @@ -0,0 +1,123 @@ +"""Tests for tasks raising the Reschedule exception and Scheduler.reschedule(). + +Note that this functionality is also used by work stealing; +see test_steal.py for additional tests. +""" +from __future__ import annotations + +import asyncio +from time import sleep + +import pytest + +from distributed import Event, Reschedule, get_worker, secede, wait +from distributed.utils_test import captured_logger, gen_cluster, slowinc +from distributed.worker_state_machine import ( + ComputeTaskEvent, + FreeKeysEvent, + RescheduleEvent, + SecedeEvent, +) + + +@gen_cluster( + client=True, + nthreads=[("", 1)] * 2, + config={"distributed.scheduler.work-stealing": False}, +) +async def test_scheduler_reschedule(c, s, a, b): + xs = c.map(slowinc, range(100), key="x", delay=0.1) + while not a.state.tasks or not b.state.tasks: + await asyncio.sleep(0.01) + assert len(a.state.tasks) == len(b.state.tasks) == 50 + + ys = c.map(slowinc, range(100), key="y", delay=0.1, workers=[a.address]) + while len(a.state.tasks) != 150: + await asyncio.sleep(0.01) + + # Reschedule the 50 xs that are processing on a + for x in xs: + if s.tasks[x.key].processing_on is s.workers[a.address]: + s.reschedule(x.key, stimulus_id="test") + + # Wait for at least some of the 50 xs that had been scheduled on a to move to b. + # This happens because you have 100 ys processing on a and 50 xs processing on b, + # so the scheduler will prefer b for the rescheduled tasks to obtain more equal + # balancing. + while len(a.state.tasks) == 150 or len(b.state.tasks) <= 50: + await asyncio.sleep(0.01) + + +@gen_cluster() +async def test_scheduler_reschedule_warns(s, a, b): + with captured_logger("distributed.scheduler") as sched: + s.reschedule(key="__this-key-does-not-exist__", stimulus_id="test") + + assert "not found on the scheduler" in sched.getvalue() + assert "Aborting reschedule" in sched.getvalue() + + +@pytest.mark.parametrize("long_running", [False, True]) +@gen_cluster( + client=True, + nthreads=[("", 1)] * 2, + config={"distributed.scheduler.work-stealing": False}, +) +async def test_raise_reschedule(c, s, a, b, long_running): + """A task raises Reschedule()""" + a_address = a.address + + def f(x): + if long_running: + secede() + sleep(0.1) + if get_worker().address == a_address: + raise Reschedule() + + futures = c.map(f, range(4), key=["x1", "x2", "x3", "x4"]) + futures2 = c.map(slowinc, range(10), delay=0.1, key="clog", workers=[a.address]) + await wait(futures) + assert any(isinstance(ev, RescheduleEvent) for ev in a.state.stimulus_log) + assert all(f.key in b.data for f in futures) + + +@pytest.mark.parametrize("long_running", [False, True]) +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_cancelled_reschedule(c, s, a, long_running): + """A task raises Reschedule(), but the future was released by the client""" + ev1 = Event() + ev2 = Event() + + def f(ev1, ev2): + if long_running: + secede() + ev1.set() + ev2.wait() + raise Reschedule() + + x = c.submit(f, ev1, ev2, key="x") + await ev1.wait() + x.release() + while "x" in s.tasks: + await asyncio.sleep(0.01) + + await ev2.set() + while "x" in a.state.tasks: + await asyncio.sleep(0.01) + + +@pytest.mark.parametrize("long_running", [False, True]) +def test_cancelled_reschedule_worker_state(ws, long_running): + """Same as test_cancelled_reschedule""" + + ws.handle_stimulus(ComputeTaskEvent.dummy(key="x", stimulus_id="s1")) + if long_running: + ws.handle_stimulus(SecedeEvent(key="x", compute_duration=1.0, stimulus_id="s2")) + + instructions = ws.handle_stimulus( + FreeKeysEvent(keys=["x"], stimulus_id="s3"), + RescheduleEvent(key="x", stimulus_id="s4"), + ) + # There's no RescheduleMsg and the task has been forgotten + assert not instructions + assert not ws.tasks diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index a8fa1c4e5d..b90a7689e5 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1132,9 +1132,12 @@ async def test_balance_many_workers(c, s, *workers): @nodebug -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 30) +@gen_cluster( + client=True, + nthreads=[("127.0.0.1", 1)] * 30, + config={"distributed.scheduler.work-stealing": False}, +) async def test_balance_many_workers_2(c, s, *workers): - await s.extensions["stealing"].stop() futures = c.map(slowinc, range(90), delay=0.2) await wait(futures) assert {len(w.has_what) for w in s.workers.values()} == {3} @@ -1513,35 +1516,6 @@ async def test_log_tasks_during_restart(c, s, a, b): assert "exit" in str(s.events) -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) -async def test_reschedule(c, s, a, b): - await c.submit(slowinc, -1, delay=0.1) # learn cost - x = c.map(slowinc, range(4), delay=0.1) - - # add much more work onto worker a - futures = c.map(slowinc, range(10, 20), delay=0.1, workers=a.address) - - while len(s.tasks) < len(x) + len(futures): - await asyncio.sleep(0.001) - - for future in x: - s.reschedule(key=future.key) - - # Worker b gets more of the original tasks - await wait(x) - assert sum(future.key in b.data for future in x) >= 3 - assert sum(future.key in a.data for future in x) <= 1 - - -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) -async def test_reschedule_warns(c, s, a, b): - with captured_logger(logging.getLogger("distributed.scheduler")) as sched: - s.reschedule(key="__this-key-does-not-exist__") - - assert "not found on the scheduler" in sched.getvalue() - assert "Aborting reschedule" in sched.getvalue() - - @gen_cluster(client=True) async def test_get_task_status(c, s, a, b): future = c.submit(inc, 1) @@ -3342,12 +3316,11 @@ async def test_worker_heartbeat_after_cancel(c, s, *workers): @gen_cluster(client=True, nthreads=[("", 1)] * 2) async def test_set_restrictions(c, s, a, b): - - f = c.submit(inc, 1, workers=[b.address]) + f = c.submit(inc, 1, key="f", workers=[b.address]) await f s.set_restrictions(worker={f.key: a.address}) assert s.tasks[f.key].worker_restrictions == {a.address} - s.reschedule(f) + await b.close() await f diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 734440d3b4..98a6b48e6b 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -1104,7 +1104,7 @@ async def test_steal_reschedule_reset_in_flight_occupancy(c, s, *workers): steal.move_task_request(victim_ts, wsA, wsB) - s.reschedule(victim_key) + s.reschedule(victim_key, stimulus_id="test") await c.gather(futs1) del futs1 @@ -1238,7 +1238,7 @@ async def test_reschedule_concurrent_requests_deadlock(c, s, *workers): steal.move_task_request(victim_ts, wsA, wsB) s.set_restrictions(worker={victim_key: [wsB.address]}) - s.reschedule(victim_key) + s.reschedule(victim_key, stimulus_id="test") assert wsB == victim_ts.processing_on # move_task_request is not responsible for respecting worker restrictions steal.move_task_request(victim_ts, wsB, wsC) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 41fb58f588..676ad41bd4 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -32,7 +32,6 @@ Client, Event, Nanny, - Reschedule, default_client, get_client, get_worker, @@ -1181,23 +1180,6 @@ def some_name(): assert result.startswith("some_name") -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) -async def test_reschedule(c, s, a, b): - await s.extensions["stealing"].stop() - a_address = a.address - - def f(x): - sleep(0.1) - if get_worker().address == a_address: - raise Reschedule() - - futures = c.map(f, range(4)) - futures2 = c.map(slowinc, range(10), delay=0.1, workers=a.address) - await wait(futures) - - assert all(f.key in b.data for f in futures) - - @gen_cluster(nthreads=[]) async def test_deque_handler(s): from distributed.worker import logger diff --git a/distributed/worker.py b/distributed/worker.py index 0b61d147a0..577e75775e 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2236,7 +2236,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No ) if isinstance(result["actual-exception"], Reschedule): - return RescheduleEvent(key=ts.key, stimulus_id=stimulus_id) + return RescheduleEvent(key=ts.key, stimulus_id=f"reschedule-{time()}") logger.warning( "Compute Failed\n" diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index cec51d8c54..c03ddef49c 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -1725,9 +1725,7 @@ def _transition_waiting_constrained( def _transition_long_running_rescheduled( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: - recs: Recs = {ts: "released"} - smsg = RescheduleMsg(key=ts.key, stimulus_id=stimulus_id) - return recs, [smsg] + return {ts: "released"}, [RescheduleMsg(key=ts.key, stimulus_id=stimulus_id)] def _transition_executing_rescheduled( self, ts: TaskState, *, stimulus_id: str @@ -1737,10 +1735,7 @@ def _transition_executing_rescheduled( self.executing.discard(ts) return merge_recs_instructions( - ( - {ts: "released"}, - [RescheduleMsg(key=ts.key, stimulus_id=stimulus_id)], - ), + ({ts: "released"}, [RescheduleMsg(key=ts.key, stimulus_id=stimulus_id)]), self._ensure_computing(), ) @@ -1987,7 +1982,7 @@ def _transition_executing_released( # See https://github.com/dask/distributed/pull/5046#discussion_r685093940 ts.state = "cancelled" ts.done = False - return self._ensure_computing() + return {}, [] def _transition_long_running_memory( self, ts: TaskState, value: object = NO_VALUE, *, stimulus_id: str @@ -2201,6 +2196,7 @@ def _transition_released_forgotten( ("cancelled", "missing"): _transition_cancelled_released, ("cancelled", "waiting"): _transition_cancelled_waiting, ("cancelled", "forgotten"): _transition_cancelled_forgotten, + ("cancelled", "rescheduled"): _transition_cancelled_released, ("cancelled", "memory"): _transition_cancelled_memory, ("cancelled", "error"): _transition_cancelled_error, ("resumed", "memory"): _transition_generic_memory, @@ -2821,6 +2817,8 @@ def _handle_reschedule(self, ev: RescheduleEvent) -> RecsInstrs: # without going through cancelled ts = self.tasks.get(ev.key) assert ts, self.story(ev.key) + + ts.done = True return {ts: "rescheduled"}, [] @_handle_event.register