Skip to content

Commit

Permalink
Revisit WorkerState.long_running set (#6697)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored Jul 8, 2022
1 parent d59500e commit 1b34993
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 24 deletions.
2 changes: 1 addition & 1 deletion distributed/tests/test_cancelled_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ def test_workerstate_executing_skips_fetch_on_success(ws_with_running_task):
assert ws.data["x"] == 123


@pytest.mark.xfail(reason="distributed#6565, distributed#6689")
@pytest.mark.xfail(reason="distributed#6689")
def test_workerstate_executing_failure_to_fetch(ws_with_running_task):
"""Test state loops:
Expand Down
11 changes: 8 additions & 3 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,7 +1077,6 @@ def test_running_task_in_all_running_tasks(ws_with_running_task):
assert ts in ws.all_running_tasks


@pytest.mark.xfail(reason="distributed#6565, distributed#6692")
@pytest.mark.parametrize(
"done_ev_cls,done_status",
[(ExecuteSuccessEvent, "memory"), (ExecuteFailureEvent, "error")],
Expand All @@ -1094,10 +1093,16 @@ def test_done_task_not_in_all_running_tasks(
assert ts not in ws.all_running_tasks


@pytest.mark.xfail(reason="distributed#6565, distributed#6689, distributed#6692")
@pytest.mark.parametrize(
"done_ev_cls,done_status",
[(ExecuteSuccessEvent, "memory"), (ExecuteFailureEvent, "error")],
[
(ExecuteSuccessEvent, "memory"),
pytest.param(
ExecuteFailureEvent,
"error",
marks=pytest.mark.xfail(reason="distributed#6689"),
),
],
)
def test_done_resumed_task_not_in_all_running_tasks(
ws_with_running_task, done_ev_cls, done_status
Expand Down
58 changes: 38 additions & 20 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,10 +1099,10 @@ class WorkerState:
#: See also :meth:`executing_count` and :attr:`long_runing`.
executing: set[TaskState]

#: Set of keys of tasks that are currently running and have called
#: Set of tasks that are currently running and have called
#: :func:`~distributed.secede`.
#: These tasks do not appear in the :attr:`executing` set.
long_running: set[str]
long_running: set[TaskState]

#: A number of tasks that this worker has run in its lifetime.
#: See also :meth:`executing_count`.
Expand Down Expand Up @@ -1241,7 +1241,7 @@ def all_running_tasks(self) -> set[TaskState]:
- ``ts.status == "resumed" and ts._previous in ("executing", "long-running")``
"""
# Note: cancelled and resumed tasks are still in either of these sets
return self.executing | {self.tasks[key] for key in self.long_running}
return self.executing | self.long_running

@property
def in_flight_tasks_count(self) -> int:
Expand Down Expand Up @@ -1342,6 +1342,7 @@ def _purge_state(self, ts: TaskState) -> None:
ts.done = False

self.executing.discard(ts)
self.long_running.discard(ts)
self.in_flight_tasks.discard(ts)

def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs:
Expand Down Expand Up @@ -1808,6 +1809,7 @@ def _transition_executing_rescheduled(

self._release_resources(ts)
self.executing.discard(ts)
self.long_running.discard(ts)

return merge_recs_instructions(
({}, [RescheduleMsg(key=ts.key, stimulus_id=stimulus_id)]),
Expand Down Expand Up @@ -1845,7 +1847,7 @@ def _transition_cancelled_error(
*,
stimulus_id: str,
) -> RecsInstrs:
assert ts._previous == "executing" or ts.key in self.long_running
assert ts._previous in ("executing", "long-running")
recs, instructions = self._transition_executing_error(
ts,
exception,
Expand Down Expand Up @@ -1904,6 +1906,7 @@ def _transition_executing_error(
) -> RecsInstrs:
self._release_resources(ts)
self.executing.discard(ts)
self.long_running.discard(ts)

return merge_recs_instructions(
self._transition_generic_error(
Expand Down Expand Up @@ -1956,7 +1959,7 @@ def _transition_from_resumed(
assert finish != "memory"
next_state = ts._next
assert next_state in {"waiting", "fetch"}, next_state
assert ts._previous in {"executing", "flight"}, ts._previous
assert ts._previous in {"executing", "long-running", "flight"}, ts._previous

if next_state != finish:
recs, instructions = self._transition_generic_released(
Expand Down Expand Up @@ -2012,7 +2015,7 @@ def _transition_cancelled_fetch(
ts.state = ts._previous
return {}, []
else:
assert ts._previous in {"executing", "long-running"}
assert ts._previous in ("executing", "long-running")
ts.state = "resumed"
ts._next = "fetch"
return {}, []
Expand Down Expand Up @@ -2045,6 +2048,7 @@ def _transition_cancelled_released(
if not ts.done:
return {}, []
self.executing.discard(ts)
self.long_running.discard(ts)
self.in_flight_tasks.discard(ts)

self._release_resources(ts)
Expand All @@ -2060,12 +2064,6 @@ def _transition_executing_released(
ts.done = False
return {}, []

def _transition_long_running_memory(
self, ts: TaskState, value: object = NO_VALUE, *, stimulus_id: str
) -> RecsInstrs:
self.executed_count += 1
return self._transition_generic_memory(ts, value=value, stimulus_id=stimulus_id)

def _transition_generic_memory(
self, ts: TaskState, value: object = NO_VALUE, *, stimulus_id: str
) -> RecsInstrs:
Expand All @@ -2076,6 +2074,7 @@ def _transition_generic_memory(

self._release_resources(ts)
self.executing.discard(ts)
self.long_running.discard(ts)
self.in_flight_tasks.discard(ts)
ts.coming_from = None

Expand All @@ -2098,11 +2097,12 @@ def _transition_executing_memory(
self, ts: TaskState, value: object = NO_VALUE, *, stimulus_id: str
) -> RecsInstrs:
if self.validate:
assert ts.state == "executing" or ts.key in self.long_running
assert ts.state in ("executing", "long-running")
assert not ts.waiting_for_data
assert ts.key not in self.ready

self.executing.discard(ts)
self.long_running.discard(ts)
self.executed_count += 1
return merge_recs_instructions(
self._transition_generic_memory(ts, value=value, stimulus_id=stimulus_id),
Expand Down Expand Up @@ -2202,7 +2202,7 @@ def _transition_executing_long_running(
) -> RecsInstrs:
ts.state = "long-running"
self.executing.discard(ts)
self.long_running.add(ts.key)
self.long_running.add(ts)

smsg = LongRunningMsg(
key=ts.key, compute_duration=compute_duration, stimulus_id=stimulus_id
Expand Down Expand Up @@ -2294,8 +2294,8 @@ def _transition_released_forgotten(
("flight", "memory"): _transition_flight_memory,
("flight", "missing"): _transition_flight_missing,
("flight", "released"): _transition_flight_released,
("long-running", "error"): _transition_generic_error,
("long-running", "memory"): _transition_long_running_memory,
("long-running", "error"): _transition_executing_error,
("long-running", "memory"): _transition_executing_memory,
("long-running", "rescheduled"): _transition_executing_rescheduled,
("long-running", "released"): _transition_executing_released,
("memory", "released"): _transition_memory_released,
Expand Down Expand Up @@ -2991,7 +2991,7 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
for w, tss in self.data_needed.items()
},
"executing": {ts.key for ts in self.executing},
"long_running": self.long_running,
"long_running": {ts.key for ts in self.long_running},
"in_flight_tasks": {ts.key for ts in self.in_flight_tasks},
"in_flight_workers": self.in_flight_workers,
"busy_workers": self.busy_workers,
Expand All @@ -3015,7 +3015,14 @@ def _validate_task_memory(self, ts: TaskState) -> None:
assert ts.state == "memory"

def _validate_task_executing(self, ts: TaskState) -> None:
assert ts.state == "executing"
if ts.state == "executing":
assert ts in self.executing
assert ts not in self.long_running
else:
assert ts.state == "long-running"
assert ts not in self.executing
assert ts in self.long_running

assert ts.run_spec is not None
assert ts.key not in self.data
assert not ts.waiting_for_data
Expand Down Expand Up @@ -3072,7 +3079,7 @@ def _validate_task_cancelled(self, ts: TaskState) -> None:

def _validate_task_resumed(self, ts: TaskState) -> None:
assert ts.key not in self.data
assert ts._next
assert ts._next in {"fetch", "waiting"}
assert ts._previous in {"long-running", "executing", "flight"}

def _validate_task_released(self, ts: TaskState) -> None:
Expand Down Expand Up @@ -3113,7 +3120,7 @@ def validate_task(self, ts: TaskState) -> None:
self._validate_task_resumed(ts)
elif ts.state == "ready":
self._validate_task_ready(ts)
elif ts.state == "executing":
elif ts.state in ("executing", "long-running"):
self._validate_task_executing(ts)
elif ts.state == "flight":
self._validate_task_flight(ts)
Expand Down Expand Up @@ -3164,13 +3171,24 @@ def validate_state(self) -> None:
assert ts.state == "fetch"
assert worker in ts.who_has

# FIXME https://github.com/dask/distributed/issues/6689
# for ts in self.executing:
# assert ts.state == "executing" or (
# ts.state in ("cancelled", "resumed") and ts._previous == "executing"
# ), self.story(ts)
# for ts in self.long_running:
# assert ts.state == "long-running" or (
# ts.state in ("cancelled", "resumed") and ts._previous == "long-running"
# ), self.story(ts)

# Test that there aren't multiple TaskState objects with the same key in any
# Set[TaskState]. See note in TaskState.__hash__.
for ts in chain(
*self.data_needed.values(),
self.missing_dep_flight,
self.in_flight_tasks,
self.executing,
self.long_running,
):
assert self.tasks[ts.key] is ts

Expand Down

0 comments on commit 1b34993

Please sign in to comment.