From a6557e5849219f860e5e1210ff8afbff0c440329 Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Thu, 18 Aug 2022 20:28:28 +0100
Subject: [PATCH] cancelled/resumed->long-running transitions

---
 distributed/scheduler.py                  | 13 ++---
 distributed/tests/test_cancelled_state.py | 63 ++++++++++++++++++++++-
 distributed/worker_state_machine.py       | 62 ++++++++++++++++++----
 3 files changed, 119 insertions(+), 19 deletions(-)

diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index e38447c70b9..3f758c3d9cc 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -4770,7 +4770,7 @@ def release_worker_data(self, key: str, worker: str, stimulus_id: str) -> None:
         self.transitions({key: "released"}, stimulus_id)
 
     def handle_long_running(
-        self, key: str, worker: str, compute_duration: float, stimulus_id: str
+        self, key: str, worker: str, compute_duration: float | None, stimulus_id: str
     ) -> None:
         """A task has seceded from the thread pool
 
@@ -4790,11 +4790,12 @@ def handle_long_running(
             logger.debug("Received long-running signal from duplicate task. Ignoring.")
             return
 
-        old_duration = ts.prefix.duration_average
-        if old_duration < 0:
-            ts.prefix.duration_average = compute_duration
-        else:
-            ts.prefix.duration_average = (old_duration + compute_duration) / 2
+        if compute_duration is not None:
+            old_duration = ts.prefix.duration_average
+            if old_duration < 0:
+                ts.prefix.duration_average = compute_duration
+            else:
+                ts.prefix.duration_average = (old_duration + compute_duration) / 2
 
         occ = ws.processing[ts]
         ws.occupancy -= occ
diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py
index de1e4ecb880..99000b07fe4 100644
--- a/distributed/tests/test_cancelled_state.py
+++ b/distributed/tests/test_cancelled_state.py
@@ -30,6 +30,8 @@
     GatherDepFailureEvent,
     GatherDepNetworkFailureEvent,
     GatherDepSuccessEvent,
+    LongRunningMsg,
+    SecedeEvent,
     TaskFinishedMsg,
     UpdateDataEvent,
 )
@@ -639,7 +641,12 @@ def test_workerstate_executing_to_executing(ws_with_running_task):
         FreeKeysEvent(keys=["x"], stimulus_id="s1"),
         ComputeTaskEvent.dummy("x", resource_restrictions={"R": 1}, stimulus_id="s2"),
     )
-    assert not instructions
+    if prev_state == "executing":
+        assert not instructions
+    else:
+        assert instructions == [
+            LongRunningMsg(key="x", compute_duration=None, stimulus_id="s2")
+        ]
     assert ws.tasks["x"] is ts
     assert ts.state == prev_state
 
@@ -820,7 +827,12 @@ def test_workerstate_resumed_fetch_to_executing(ws_with_running_task):
         FreeKeysEvent(keys=["y", "x"], stimulus_id="s3"),
         ComputeTaskEvent.dummy("x", resource_restrictions={"R": 1}, stimulus_id="s4"),
     )
-    assert not instructions
+    if prev_state == "executing":
+        assert not instructions
+    else:
+        assert instructions == [
+            LongRunningMsg(key="x", compute_duration=None, stimulus_id="s4")
+        ]
     assert ws.tasks["x"].state == prev_state
 
 
@@ -943,3 +955,50 @@ def test_cancel_with_dependencies_in_memory(ws, release_dep, done_ev_cls):
     ws.handle_stimulus(done_ev_cls.dummy("y", stimulus_id="s5"))
     assert "y" not in ws.tasks
     assert ws.tasks["x"].state == "memory"
+
+
+@pytest.mark.parametrize("resume_to_fetch", [False, True])
+@pytest.mark.parametrize("resume_to_executing", [False, True])
+@pytest.mark.parametrize("done_ev_cls", [ExecuteSuccessEvent, ExecuteFailureEvent])
+def test_secede_cancelled_or_resumed(
+    ws, resume_to_fetch, resume_to_executing, done_ev_cls
+):
+    """Test what happens when a cancelled or resumed(fetch) task calls secede()"""
+    ws2 = "127.0.0.1:2"
+    ws.handle_stimulus(
+        ComputeTaskEvent.dummy("x", stimulus_id="s1"),
+        FreeKeysEvent(keys=["x"], stimulus_id="s2"),
+    )
+    if resume_to_fetch:
+        ws.handle_stimulus(
+            ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s3"),
+        )
+    ts = ws.tasks["x"]
+    assert ts.previous == "executing"
+    assert ts in ws.executing
+    assert ts not in ws.long_running
+
+    instructions = ws.handle_stimulus(
+        SecedeEvent(key="x", compute_duration=1, stimulus_id="s4")
+    )
+    assert not instructions  # Do not send RescheduleMsg
+    assert ts.previous == "long-running"
+    assert ts not in ws.executing
+    assert ts in ws.long_running
+
+    if resume_to_executing:
+        instructions = ws.handle_stimulus(
+            FreeKeysEvent(keys=["y"], stimulus_id="s5"),
+            ComputeTaskEvent.dummy("x", stimulus_id="s6"),
+        )
+        # Inform the scheduler of the SecedeEvent that happened in the past
+        assert instructions == [
+            LongRunningMsg(key="x", compute_duration=None, stimulus_id="s6")
+        ]
+        assert ts.state == "long-running"
+        assert ts not in ws.executing
+        assert ts in ws.long_running
+
+    ws.handle_stimulus(done_ev_cls.dummy(key="x", stimulus_id="s7"))
+    assert ts not in ws.executing
+    assert ts not in ws.long_running
diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py
index 6f6c6978a26..91d451be663 100644
--- a/distributed/worker_state_machine.py
+++ b/distributed/worker_state_machine.py
@@ -513,7 +513,7 @@ class LongRunningMsg(SendMessageToScheduler):
 
     __slots__ = ("key", "compute_duration")
     key: str
-    compute_duration: float
+    compute_duration: float | None
 
 
 @dataclass
@@ -2024,24 +2024,39 @@ def _transition_resumed_waiting(
         See also
         --------
         _transition_cancelled_fetch
+        _transition_cancelled_long_running
         _transition_cancelled_waiting
        _transition_resumed_fetch
         """
         # None of the exit events of execute or gather_dep recommend a transition to
         # waiting
         assert not ts.done
-        if ts.previous in ("executing", "long-running"):
+        if ts.previous == "executing":
             assert ts.next == "fetch"
             # We're back where we started. We should forget about the entire
             # cancellation attempt
-            ts.state = ts.previous
+            ts.state = "executing"
             ts.next = None
             ts.previous = None
-        elif self.validate:
+            return {}, []
+
+        elif ts.previous == "long-running":
+            assert ts.next == "fetch"
+            # Same as executing, and in addition send the LongRunningMsg in arrears
+            # Note that, if the task seceded before it was cancelled, this will cause
+            # the message to be sent twice.
+            ts.state = "long-running"
+            ts.next = None
+            ts.previous = None
+            smsg = LongRunningMsg(
+                key=ts.key, compute_duration=None, stimulus_id=stimulus_id
+            )
+            return {}, [smsg]
+
+        else:
             assert ts.previous == "flight"
             assert ts.next == "waiting"
-
-        return {}, []
+            return {}, []
 
     def _transition_cancelled_fetch(
         self, ts: TaskState, *, stimulus_id: str
     ) -> RecsInstrs:
@@ -2078,17 +2093,29 @@ def _transition_cancelled_waiting(
         See also
         --------
         _transition_cancelled_fetch
+        _transition_cancelled_long_running
         _transition_resumed_fetch
         _transition_resumed_waiting
         """
         # None of the exit events of gather_dep or execute recommend a transition to
         # waiting
         assert not ts.done
-        if ts.previous in ("executing", "long-running"):
+        if ts.previous == "executing":
             # Forget the task was cancelled to begin with
-            ts.state = ts.previous
+            ts.state = "executing"
             ts.previous = None
             return {}, []
+        elif ts.previous == "long-running":
+            # Forget the task was cancelled to begin with, and inform the scheduler
+            # in arrears that it has seceded.
+            # Note that, if the task seceded before it was cancelled, this will cause
+            # the message to be sent twice.
+            ts.state = "long-running"
+            ts.previous = None
+            smsg = LongRunningMsg(
+                key=ts.key, compute_duration=None, stimulus_id=stimulus_id
+            )
+            return {}, [smsg]
         else:
             assert ts.previous == "flight"
             ts.state = "resumed"
@@ -2193,6 +2220,18 @@ def _transition_executing_long_running(
             self._ensure_computing(),
         )
 
+    def _transition_cancelled_long_running(
+        self, ts: TaskState, compute_duration: float, *, stimulus_id: str
+    ) -> RecsInstrs:
+        """This transition also serves resumed(fetch) -> long-running"""
+        assert ts.previous in ("executing", "long-running")
+        ts.previous = "long-running"
+        self.executing.discard(ts)
+        self.long_running.add(ts)
+
+        # Do not send LongRunningMsg
+        return self._ensure_computing()
+
     def _transition_executing_memory(
         self, ts: TaskState, value: object, *, stimulus_id: str
     ) -> RecsInstrs:
@@ -2299,6 +2338,7 @@ def _transition_released_forgotten(
     ] = {
         ("cancelled", "fetch"): _transition_cancelled_fetch,
         ("cancelled", "error"): _transition_cancelled_released,
+        ("cancelled", "long-running"): _transition_cancelled_long_running,
         ("cancelled", "memory"): _transition_cancelled_released,
         ("cancelled", "missing"): _transition_cancelled_released,
         ("cancelled", "released"): _transition_cancelled_released,
@@ -2306,6 +2346,7 @@
         ("cancelled", "waiting"): _transition_cancelled_waiting,
         ("resumed", "memory"): _transition_resumed_memory,
         ("resumed", "error"): _transition_resumed_error,
+        ("resumed", "long-running"): _transition_cancelled_long_running,
         ("resumed", "released"): _transition_resumed_released,
         ("resumed", "waiting"): _transition_resumed_waiting,
         ("resumed", "fetch"): _transition_resumed_fetch,
@@ -2834,10 +2875,9 @@ def _handle_gather_dep_failure(self, ev: GatherDepFailureEvent) -> RecsInstrs:
     @_handle_event.register
     def _handle_secede(self, ev: SecedeEvent) -> RecsInstrs:
         ts = self.tasks.get(ev.key)
-        if ts and ts.state == "executing":
-            return {ts: ("long-running", ev.compute_duration)}, []
-        else:
+        if not ts:
             return {}, []
+        return {ts: ("long-running", ev.compute_duration)}, []
 
     @_handle_event.register
     def _handle_steal_request(self, ev: StealRequestEvent) -> RecsInstrs: