diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 8ffb215025f..dd30424a671 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -5066,7 +5066,7 @@ def release_worker_data(self, key: str, worker: str, stimulus_id: str) -> None:
             self.transitions({key: "released"}, stimulus_id)
 
     def handle_long_running(
-        self, key: str, worker: str, compute_duration: float, stimulus_id: str
+        self, key: str, worker: str, compute_duration: float | None, stimulus_id: str
     ) -> None:
         """A task has seceded from the thread pool
 
@@ -5086,11 +5086,12 @@ def handle_long_running(
             logger.debug("Received long-running signal from duplicate task. Ignoring.")
             return
 
-        old_duration = ts.prefix.duration_average
-        if old_duration < 0:
-            ts.prefix.duration_average = compute_duration
-        else:
-            ts.prefix.duration_average = (old_duration + compute_duration) / 2
+        if compute_duration is not None:
+            old_duration = ts.prefix.duration_average
+            if old_duration < 0:
+                ts.prefix.duration_average = compute_duration
+            else:
+                ts.prefix.duration_average = (old_duration + compute_duration) / 2
 
         occ = ws.processing[ts]
         ws.occupancy -= occ
diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py
index f1a04576413..3336fc8481d 100644
--- a/distributed/tests/test_cancelled_state.py
+++ b/distributed/tests/test_cancelled_state.py
@@ -30,7 +30,9 @@
     GatherDepFailureEvent,
     GatherDepNetworkFailureEvent,
     GatherDepSuccessEvent,
+    LongRunningMsg,
     RescheduleEvent,
+    SecedeEvent,
     TaskFinishedMsg,
     UpdateDataEvent,
 )
@@ -640,7 +642,12 @@ def test_workerstate_executing_to_executing(ws_with_running_task):
         FreeKeysEvent(keys=["x"], stimulus_id="s1"),
         ComputeTaskEvent.dummy("x", resource_restrictions={"R": 1}, stimulus_id="s2"),
     )
-    assert not instructions
+    if prev_state == "executing":
+        assert not instructions
+    else:
+        assert instructions == [
+            LongRunningMsg(key="x", compute_duration=None, stimulus_id="s2")
+        ]
     assert ws.tasks["x"] is ts
     assert ts.state == prev_state
 
@@ -821,7 +828,12 @@ def test_workerstate_resumed_fetch_to_executing(ws_with_running_task):
         FreeKeysEvent(keys=["y", "x"], stimulus_id="s3"),
         ComputeTaskEvent.dummy("x", resource_restrictions={"R": 1}, stimulus_id="s4"),
     )
-    assert not instructions
+    if prev_state == "executing":
+        assert not instructions
+    else:
+        assert instructions == [
+            LongRunningMsg(key="x", compute_duration=None, stimulus_id="s4")
+        ]
     assert ws.tasks["x"].state == prev_state
 
 
@@ -946,3 +958,102 @@ def test_cancel_with_dependencies_in_memory(ws, release_dep, done_ev_cls):
     ws.handle_stimulus(done_ev_cls.dummy("y", stimulus_id="s5"))
     assert "y" not in ws.tasks
     assert ws.tasks["x"].state == "memory"
+
+
+@pytest.mark.parametrize("resume_to_fetch", [False, True])
+@pytest.mark.parametrize("resume_to_executing", [False, True])
+@pytest.mark.parametrize(
+    "done_ev_cls", [ExecuteSuccessEvent, ExecuteFailureEvent, RescheduleEvent]
+)
+def test_secede_cancelled_or_resumed_workerstate(
+    ws, resume_to_fetch, resume_to_executing, done_ev_cls
+):
+    """Test what happens when a cancelled or resumed(fetch) task calls secede().
+    See also test_secede_cancelled_or_resumed_scheduler
+    """
+    ws2 = "127.0.0.1:2"
+    ws.handle_stimulus(
+        ComputeTaskEvent.dummy("x", stimulus_id="s1"),
+        FreeKeysEvent(keys=["x"], stimulus_id="s2"),
+    )
+    if resume_to_fetch:
+        ws.handle_stimulus(
+            ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s3"),
+        )
+    ts = ws.tasks["x"]
+    assert ts.previous == "executing"
+    assert ts in ws.executing
+    assert ts not in ws.long_running
+
+    instructions = ws.handle_stimulus(
+        SecedeEvent(key="x", compute_duration=1, stimulus_id="s4")
+    )
+    assert not instructions  # Do not send LongRunningMsg
+    assert ts.previous == "long-running"
+    assert ts not in ws.executing
+    assert ts in ws.long_running
+
+    if resume_to_executing:
+        instructions = ws.handle_stimulus(
+            FreeKeysEvent(keys=["y"], stimulus_id="s5"),
+            ComputeTaskEvent.dummy("x", stimulus_id="s6"),
+        )
+        # Inform the scheduler of the SecedeEvent that happened in the past
+        assert instructions == [
+            LongRunningMsg(key="x", compute_duration=None, stimulus_id="s6")
+        ]
+        assert ts.state == "long-running"
+        assert ts not in ws.executing
+        assert ts in ws.long_running
+
+    ws.handle_stimulus(done_ev_cls.dummy(key="x", stimulus_id="s7"))
+    assert ts not in ws.executing
+    assert ts not in ws.long_running
+
+
+@gen_cluster(client=True, nthreads=[("", 1)], timeout=2)
+async def test_secede_cancelled_or_resumed_scheduler(c, s, a):
+    """Same as test_secede_cancelled_or_resumed_workerstate, but testing the interaction
+    with the scheduler
+    """
+    ws = s.workers[a.address]
+    ev1 = Event()
+    ev2 = Event()
+    ev3 = Event()
+    ev4 = Event()
+
+    def f(ev1, ev2, ev3, ev4):
+        ev1.set()
+        ev2.wait()
+        distributed.secede()
+        ev3.set()
+        ev4.wait()
+        return 123
+
+    x = c.submit(f, ev1, ev2, ev3, ev4, key="x")
+    await ev1.wait()
+    ts = a.state.tasks["x"]
+    assert ts.state == "executing"
+    assert sum(ws.processing.values()) > 0
+
+    x.release()
+    await wait_for_state("x", "cancelled", a)
+    assert not ws.processing
+
+    await ev2.set()
+    await ev3.wait()
+    assert ts.previous == "long-running"
+    assert not ws.processing
+
+    x = c.submit(inc, 1, key="x")
+    await wait_for_state("x", "long-running", a)
+
+    # Test that the scheduler receives a delayed {op: long-running}
+    assert ws.processing
+    while sum(ws.processing.values()):
+        await asyncio.sleep(0.1)
+    assert ws.processing
+
+    await ev4.set()
+    assert await x == 123
+    assert not ws.processing
diff --git a/distributed/tests/test_resources.py b/distributed/tests/test_resources.py
index bf70aa30d6b..405f106cdda 100644
--- a/distributed/tests/test_resources.py
+++ b/distributed/tests/test_resources.py
@@ -17,6 +17,7 @@
     ExecuteFailureEvent,
     ExecuteSuccessEvent,
     FreeKeysEvent,
+    LongRunningMsg,
     RescheduleEvent,
     TaskFinishedMsg,
 )
@@ -565,6 +566,8 @@ def test_resumed_with_different_resources(ws_with_running_task, done_ev_cls):
     """
     ws = ws_with_running_task
     assert ws.available_resources == {"R": 0}
+    ts = ws.tasks["x"]
+    prev_state = ts.state
 
     ws.handle_stimulus(FreeKeysEvent(keys=["x"], stimulus_id="s1"))
     assert ws.available_resources == {"R": 0}
@@ -572,7 +575,12 @@ def test_resumed_with_different_resources(ws_with_running_task, done_ev_cls):
     instructions = ws.handle_stimulus(
         ComputeTaskEvent.dummy("x", stimulus_id="s2", resource_restrictions={"R": 0.4})
     )
-    assert not instructions
+    if prev_state == "long-running":
+        assert instructions == [
+            LongRunningMsg(key="x", compute_duration=None, stimulus_id="s2")
+        ]
+    else:
+        assert not instructions
     assert ws.available_resources == {"R": 0}
 
     ws.handle_stimulus(done_ev_cls.dummy(key="x", stimulus_id="s3"))
diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py
index 431c8d9a8c9..cbaefa7e72a 100644
--- a/distributed/worker_state_machine.py
+++ b/distributed/worker_state_machine.py
@@ -517,7 +517,7 @@ class LongRunningMsg(SendMessageToScheduler):
 
     __slots__ = ("key", "compute_duration")
     key: str
-    compute_duration: float
+    compute_duration: float | None
 
 
 @dataclass
@@ -2077,24 +2077,39 @@ def _transition_resumed_waiting(
         See also
         --------
         _transition_cancelled_fetch
+        _transition_cancelled_or_resumed_long_running
         _transition_cancelled_waiting
         _transition_resumed_fetch
         """
         # None of the exit events of execute or gather_dep recommend a transition to
         # waiting
         assert not ts.done
-        if ts.previous in ("executing", "long-running"):
+        if ts.previous == "executing":
             assert ts.next == "fetch"
             # We're back where we started. We should forget about the entire
             # cancellation attempt
-            ts.state = ts.previous
+            ts.state = "executing"
             ts.next = None
             ts.previous = None
-        elif self.validate:
+            return {}, []
+
+        elif ts.previous == "long-running":
+            assert ts.next == "fetch"
+            # Same as executing, and in addition send the LongRunningMsg in arrears
+            # Note that, if the task seceded before it was cancelled, this will cause
+            # the message to be sent twice.
+            ts.state = "long-running"
+            ts.next = None
+            ts.previous = None
+            smsg = LongRunningMsg(
+                key=ts.key, compute_duration=None, stimulus_id=stimulus_id
+            )
+            return {}, [smsg]
+
+        else:
             assert ts.previous == "flight"
             assert ts.next == "waiting"
-
-        return {}, []
+            return {}, []
 
     def _transition_cancelled_fetch(
         self, ts: TaskState, *, stimulus_id: str
@@ -2131,17 +2146,29 @@ def _transition_cancelled_waiting(
         See also
         --------
         _transition_cancelled_fetch
+        _transition_cancelled_or_resumed_long_running
         _transition_resumed_fetch
         _transition_resumed_waiting
         """
         # None of the exit events of gather_dep or execute recommend a transition to
         # waiting
         assert not ts.done
-        if ts.previous in ("executing", "long-running"):
+        if ts.previous == "executing":
             # Forget the task was cancelled to begin with
-            ts.state = ts.previous
+            ts.state = "executing"
             ts.previous = None
             return {}, []
+        elif ts.previous == "long-running":
+            # Forget the task was cancelled to begin with, and inform the scheduler
+            # in arrears that it has seceded.
+            # Note that, if the task seceded before it was cancelled, this will cause
+            # the message to be sent twice.
+            ts.state = "long-running"
+            ts.previous = None
+            smsg = LongRunningMsg(
+                key=ts.key, compute_duration=None, stimulus_id=stimulus_id
+            )
+            return {}, [smsg]
         else:
             assert ts.previous == "flight"
             ts.state = "resumed"
@@ -2234,6 +2261,11 @@ def _transition_flight_released(
     def _transition_executing_long_running(
         self, ts: TaskState, compute_duration: float, *, stimulus_id: str
     ) -> RecsInstrs:
+        """
+        See also
+        --------
+        _transition_cancelled_or_resumed_long_running
+        """
         ts.state = "long-running"
         self.executing.discard(ts)
         self.long_running.add(ts)
@@ -2246,6 +2278,34 @@ def _transition_executing_long_running(
             self._ensure_computing(),
         )
 
+    def _transition_cancelled_or_resumed_long_running(
+        self, ts: TaskState, compute_duration: float, *, stimulus_id: str
+    ) -> RecsInstrs:
+        """Handles transitions:
+
+        - cancelled(executing) -> long-running
+        - cancelled(long-running) -> long-running (user called secede() twice)
+        - resumed(executing->fetch) -> long-running
+        - resumed(long-running->fetch) -> long-running (user called secede() twice)
+
+        Unlike in the executing->long_running transition, do not send LongRunningMsg.
+        From the scheduler's perspective, this task no longer exists (cancelled) or is
+        in memory on another worker (resumed). So it shouldn't hear about it.
+        Instead, we're going to send the LongRunningMsg when and if the task
+        transitions back to waiting.
+
+        See also
+        --------
+        _transition_executing_long_running
+        _transition_cancelled_waiting
+        _transition_resumed_waiting
+        """
+        assert ts.previous in ("executing", "long-running")
+        ts.previous = "long-running"
+        self.executing.discard(ts)
+        self.long_running.add(ts)
+        return self._ensure_computing()
+
     def _transition_executing_memory(
         self, ts: TaskState, value: object, *, stimulus_id: str
     ) -> RecsInstrs:
@@ -2352,6 +2412,7 @@ def _transition_released_forgotten(
     ] = {
         ("cancelled", "error"): _transition_cancelled_released,
         ("cancelled", "fetch"): _transition_cancelled_fetch,
+        ("cancelled", "long-running"): _transition_cancelled_or_resumed_long_running,
         ("cancelled", "memory"): _transition_cancelled_released,
         ("cancelled", "missing"): _transition_cancelled_released,
         ("cancelled", "released"): _transition_cancelled_released,
@@ -2359,8 +2420,8 @@ def _transition_released_forgotten(
         ("cancelled", "waiting"): _transition_cancelled_waiting,
         ("resumed", "error"): _transition_resumed_error,
         ("resumed", "fetch"): _transition_resumed_fetch,
+        ("resumed", "long-running"): _transition_cancelled_or_resumed_long_running,
         ("resumed", "memory"): _transition_resumed_memory,
-        ("resumed", "missing"): _transition_resumed_missing,
         ("resumed", "released"): _transition_resumed_released,
         ("resumed", "rescheduled"): _transition_resumed_rescheduled,
         ("resumed", "waiting"): _transition_resumed_waiting,
@@ -2898,10 +2959,9 @@ def _handle_gather_dep_failure(self, ev: GatherDepFailureEvent) -> RecsInstrs:
     @_handle_event.register
     def _handle_secede(self, ev: SecedeEvent) -> RecsInstrs:
         ts = self.tasks.get(ev.key)
-        if ts and ts.state == "executing":
-            return {ts: ("long-running", ev.compute_duration)}, []
-        else:
+        if not ts:
             return {}, []
+        return {ts: ("long-running", ev.compute_duration)}, []
 
     @_handle_event.register
     def _handle_steal_request(self, ev: StealRequestEvent) -> RecsInstrs: