diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index ef6fc6ce88..d7bcf1746a 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -615,7 +615,7 @@ def test_workerstate_executing_skips_fetch_on_success(ws_with_running_task): assert ws.data["x"] == 123 -@pytest.mark.xfail(reason="distributed#6565, distributed#6689") +@pytest.mark.xfail(reason="distributed#6689") def test_workerstate_executing_failure_to_fetch(ws_with_running_task): """Test state loops: diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 85b1bd88aa..6ba7500895 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -1077,7 +1077,6 @@ def test_running_task_in_all_running_tasks(ws_with_running_task): assert ts in ws.all_running_tasks -@pytest.mark.xfail(reason="distributed#6565, distributed#6692") @pytest.mark.parametrize( "done_ev_cls,done_status", [(ExecuteSuccessEvent, "memory"), (ExecuteFailureEvent, "error")], @@ -1094,10 +1093,16 @@ def test_done_task_not_in_all_running_tasks( assert ts not in ws.all_running_tasks -@pytest.mark.xfail(reason="distributed#6565, distributed#6689, distributed#6692") @pytest.mark.parametrize( "done_ev_cls,done_status", - [(ExecuteSuccessEvent, "memory"), (ExecuteFailureEvent, "error")], + [ + (ExecuteSuccessEvent, "memory"), + pytest.param( + ExecuteFailureEvent, + "error", + marks=pytest.mark.xfail(reason="distributed#6689"), + ), + ], ) def test_done_resumed_task_not_in_all_running_tasks( ws_with_running_task, done_ev_cls, done_status diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 51b2707f56..4dbbbdf060 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -1099,10 +1099,10 @@ class WorkerState: #: See also :meth:`executing_count` and :attr:`long_runing`. executing: set[TaskState] - #: Set of keys of tasks that are currently running and have called + #: Set of tasks that are currently running and have called #: :func:`~distributed.secede`. #: These tasks do not appear in the :attr:`executing` set. - long_running: set[str] + long_running: set[TaskState] #: A number of tasks that this worker has run in its lifetime. #: See also :meth:`executing_count`. @@ -1241,7 +1241,7 @@ def all_running_tasks(self) -> set[TaskState]: - ``ts.status == "resumed" and ts._previous in ("executing", "long-running")`` """ # Note: cancelled and resumed tasks are still in either of these sets - return self.executing | {self.tasks[key] for key in self.long_running} + return self.executing | self.long_running @property def in_flight_tasks_count(self) -> int: @@ -1342,6 +1342,7 @@ def _purge_state(self, ts: TaskState) -> None: ts.done = False self.executing.discard(ts) + self.long_running.discard(ts) self.in_flight_tasks.discard(ts) def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: @@ -1808,6 +1809,7 @@ def _transition_executing_rescheduled( self._release_resources(ts) self.executing.discard(ts) + self.long_running.discard(ts) return merge_recs_instructions( ({}, [RescheduleMsg(key=ts.key, stimulus_id=stimulus_id)]), @@ -1845,7 +1847,7 @@ def _transition_cancelled_error( *, stimulus_id: str, ) -> RecsInstrs: - assert ts._previous == "executing" or ts.key in self.long_running + assert ts._previous in ("executing", "long-running") recs, instructions = self._transition_executing_error( ts, exception, @@ -1904,6 +1906,7 @@ def _transition_executing_error( ) -> RecsInstrs: self._release_resources(ts) self.executing.discard(ts) + self.long_running.discard(ts) return merge_recs_instructions( self._transition_generic_error( @@ -1956,7 +1959,7 @@ def _transition_from_resumed( assert finish != "memory" next_state = ts._next assert next_state in {"waiting", "fetch"}, next_state - assert ts._previous in {"executing", "flight"}, ts._previous + assert ts._previous in {"executing", "long-running", "flight"}, ts._previous if next_state != finish: recs, instructions = self._transition_generic_released( @@ -2012,7 +2015,7 @@ def _transition_cancelled_fetch( ts.state = ts._previous return {}, [] else: - assert ts._previous in {"executing", "long-running"} + assert ts._previous in ("executing", "long-running") ts.state = "resumed" ts._next = "fetch" return {}, [] @@ -2045,6 +2048,7 @@ def _transition_cancelled_released( if not ts.done: return {}, [] self.executing.discard(ts) + self.long_running.discard(ts) self.in_flight_tasks.discard(ts) self._release_resources(ts) @@ -2060,12 +2064,6 @@ def _transition_executing_released( ts.done = False return {}, [] - def _transition_long_running_memory( - self, ts: TaskState, value: object = NO_VALUE, *, stimulus_id: str - ) -> RecsInstrs: - self.executed_count += 1 - return self._transition_generic_memory(ts, value=value, stimulus_id=stimulus_id) - def _transition_generic_memory( self, ts: TaskState, value: object = NO_VALUE, *, stimulus_id: str ) -> RecsInstrs: @@ -2076,6 +2074,7 @@ def _transition_generic_memory( self._release_resources(ts) self.executing.discard(ts) + self.long_running.discard(ts) self.in_flight_tasks.discard(ts) ts.coming_from = None @@ -2098,11 +2097,12 @@ def _transition_executing_memory( self, ts: TaskState, value: object = NO_VALUE, *, stimulus_id: str ) -> RecsInstrs: if self.validate: - assert ts.state == "executing" or ts.key in self.long_running + assert ts.state in ("executing", "long-running") assert not ts.waiting_for_data assert ts.key not in self.ready self.executing.discard(ts) + self.long_running.discard(ts) self.executed_count += 1 return merge_recs_instructions( self._transition_generic_memory(ts, value=value, stimulus_id=stimulus_id), @@ -2202,7 +2202,7 @@ def _transition_executing_long_running( ) -> RecsInstrs: ts.state = "long-running" self.executing.discard(ts) - self.long_running.add(ts.key) + self.long_running.add(ts) smsg = LongRunningMsg( key=ts.key, compute_duration=compute_duration, stimulus_id=stimulus_id @@ -2294,8 +2294,8 @@ def _transition_released_forgotten( ("flight", "memory"): _transition_flight_memory, ("flight", "missing"): _transition_flight_missing, ("flight", "released"): _transition_flight_released, - ("long-running", "error"): _transition_generic_error, - ("long-running", "memory"): _transition_long_running_memory, + ("long-running", "error"): _transition_executing_error, + ("long-running", "memory"): _transition_executing_memory, ("long-running", "rescheduled"): _transition_executing_rescheduled, ("long-running", "released"): _transition_executing_released, ("memory", "released"): _transition_memory_released, @@ -2991,7 +2991,7 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict: for w, tss in self.data_needed.items() }, "executing": {ts.key for ts in self.executing}, - "long_running": self.long_running, + "long_running": {ts.key for ts in self.long_running}, "in_flight_tasks": {ts.key for ts in self.in_flight_tasks}, "in_flight_workers": self.in_flight_workers, "busy_workers": self.busy_workers, @@ -3015,7 +3015,14 @@ def _validate_task_memory(self, ts: TaskState) -> None: assert ts.state == "memory" def _validate_task_executing(self, ts: TaskState) -> None: - assert ts.state == "executing" + if ts.state == "executing": + assert ts in self.executing + assert ts not in self.long_running + else: + assert ts.state == "long-running" + assert ts not in self.executing + assert ts in self.long_running + assert ts.run_spec is not None assert ts.key not in self.data assert not ts.waiting_for_data @@ -3072,7 +3079,7 @@ def _validate_task_cancelled(self, ts: TaskState) -> None: def _validate_task_resumed(self, ts: TaskState) -> None: assert ts.key not in self.data - assert ts._next + assert ts._next in {"fetch", "waiting"} assert ts._previous in {"long-running", "executing", "flight"} def _validate_task_released(self, ts: TaskState) -> None: @@ -3113,7 +3120,7 @@ def validate_task(self, ts: TaskState) -> None: self._validate_task_resumed(ts) elif ts.state == "ready": self._validate_task_ready(ts) - elif ts.state == "executing": + elif ts.state in ("executing", "long-running"): self._validate_task_executing(ts) elif ts.state == "flight": self._validate_task_flight(ts) @@ -3164,6 +3171,16 @@ def validate_state(self) -> None: assert ts.state == "fetch" assert worker in ts.who_has + # FIXME https://github.com/dask/distributed/issues/6689 + # for ts in self.executing: + # assert ts.state == "executing" or ( + # ts.state in ("cancelled", "resumed") and ts._previous == "executing" + # ), self.story(ts) + # for ts in self.long_running: + # assert ts.state == "long-running" or ( + # ts.state in ("cancelled", "resumed") and ts._previous == "long-running" + # ), self.story(ts) + # Test that there aren't multiple TaskState objects with the same key in any # Set[TaskState]. See note in TaskState.__hash__. for ts in chain( @@ -3171,6 +3188,7 @@ def validate_state(self) -> None: self.missing_dep_flight, self.in_flight_tasks, self.executing, + self.long_running, ): assert self.tasks[ts.key] is ts