Skip to content

Commit

Permalink
cancelled/resumed->long-running transitions
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Aug 18, 2022
1 parent 58a5a3c commit a6557e5
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 19 deletions.
13 changes: 7 additions & 6 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4770,7 +4770,7 @@ def release_worker_data(self, key: str, worker: str, stimulus_id: str) -> None:
self.transitions({key: "released"}, stimulus_id)

def handle_long_running(
self, key: str, worker: str, compute_duration: float, stimulus_id: str
self, key: str, worker: str, compute_duration: float | None, stimulus_id: str
) -> None:
"""A task has seceded from the thread pool
Expand All @@ -4790,11 +4790,12 @@ def handle_long_running(
logger.debug("Received long-running signal from duplicate task. Ignoring.")
return

old_duration = ts.prefix.duration_average
if old_duration < 0:
ts.prefix.duration_average = compute_duration
else:
ts.prefix.duration_average = (old_duration + compute_duration) / 2
if compute_duration is not None:
old_duration = ts.prefix.duration_average
if old_duration < 0:
ts.prefix.duration_average = compute_duration
else:
ts.prefix.duration_average = (old_duration + compute_duration) / 2

occ = ws.processing[ts]
ws.occupancy -= occ
Expand Down
63 changes: 61 additions & 2 deletions distributed/tests/test_cancelled_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
LongRunningMsg,
SecedeEvent,
TaskFinishedMsg,
UpdateDataEvent,
)
Expand Down Expand Up @@ -639,7 +641,12 @@ def test_workerstate_executing_to_executing(ws_with_running_task):
FreeKeysEvent(keys=["x"], stimulus_id="s1"),
ComputeTaskEvent.dummy("x", resource_restrictions={"R": 1}, stimulus_id="s2"),
)
assert not instructions
if prev_state == "executing":
assert not instructions
else:
assert instructions == [
LongRunningMsg(key="x", compute_duration=None, stimulus_id="s2")
]
assert ws.tasks["x"] is ts
assert ts.state == prev_state

Expand Down Expand Up @@ -820,7 +827,12 @@ def test_workerstate_resumed_fetch_to_executing(ws_with_running_task):
FreeKeysEvent(keys=["y", "x"], stimulus_id="s3"),
ComputeTaskEvent.dummy("x", resource_restrictions={"R": 1}, stimulus_id="s4"),
)
assert not instructions
if prev_state == "executing":
assert not instructions
else:
assert instructions == [
LongRunningMsg(key="x", compute_duration=None, stimulus_id="s4")
]
assert ws.tasks["x"].state == prev_state


Expand Down Expand Up @@ -943,3 +955,50 @@ def test_cancel_with_dependencies_in_memory(ws, release_dep, done_ev_cls):
ws.handle_stimulus(done_ev_cls.dummy("y", stimulus_id="s5"))
assert "y" not in ws.tasks
assert ws.tasks["x"].state == "memory"


@pytest.mark.parametrize("resume_to_fetch", [False, True])
@pytest.mark.parametrize("resume_to_executing", [False, True])
@pytest.mark.parametrize("done_ev_cls", [ExecuteSuccessEvent, ExecuteFailureEvent])
def test_secede_cancelled_or_resumed(
ws, resume_to_fetch, resume_to_executing, done_ev_cls
):
"""Test what happens when a cancelled or resumed(fetch) task calls secede()"""
ws2 = "127.0.0.1:2"
ws.handle_stimulus(
ComputeTaskEvent.dummy("x", stimulus_id="s1"),
FreeKeysEvent(keys=["x"], stimulus_id="s2"),
)
if resume_to_fetch:
ws.handle_stimulus(
ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s3"),
)
ts = ws.tasks["x"]
assert ts.previous == "executing"
assert ts in ws.executing
assert ts not in ws.long_running

instructions = ws.handle_stimulus(
SecedeEvent(key="x", compute_duration=1, stimulus_id="s4")
)
assert not instructions # Do not send RescheduleMsg
assert ts.previous == "long-running"
assert ts not in ws.executing
assert ts in ws.long_running

if resume_to_executing:
instructions = ws.handle_stimulus(
FreeKeysEvent(keys=["y"], stimulus_id="s5"),
ComputeTaskEvent.dummy("x", stimulus_id="s6"),
)
# Inform the scheduler of the SecedeEvent that happened in the past
assert instructions == [
LongRunningMsg(key="x", compute_duration=None, stimulus_id="s6")
]
assert ts.state == "long-running"
assert ts not in ws.executing
assert ts in ws.long_running

ws.handle_stimulus(done_ev_cls.dummy(key="x", stimulus_id="s7"))
assert ts not in ws.executing
assert ts not in ws.long_running
62 changes: 51 additions & 11 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ class LongRunningMsg(SendMessageToScheduler):

__slots__ = ("key", "compute_duration")
key: str
compute_duration: float
compute_duration: float | None


@dataclass
Expand Down Expand Up @@ -2024,24 +2024,39 @@ def _transition_resumed_waiting(
See also
--------
_transition_cancelled_fetch
_transition_cancelled_long_running
_transition_cancelled_waiting
_transition_resumed_fetch
"""
# None of the exit events of execute or gather_dep recommend a transition to
# waiting
assert not ts.done
if ts.previous in ("executing", "long-running"):
if ts.previous == "executing":
assert ts.next == "fetch"
# We're back where we started. We should forget about the entire
# cancellation attempt
ts.state = ts.previous
ts.state = "executing"
ts.next = None
ts.previous = None
elif self.validate:
return {}, []

elif ts.previous == "long-running":
assert ts.next == "fetch"
# Same as executing, and in addition send the LongRunningMsg in arrears
# Note that, if the task seceded before it was cancelled, this will cause
# the message to be sent twice.
ts.state = "long-running"
ts.next = None
ts.previous = None
smsg = LongRunningMsg(
key=ts.key, compute_duration=None, stimulus_id=stimulus_id
)
return {}, [smsg]

else:
assert ts.previous == "flight"
assert ts.next == "waiting"

return {}, []
return {}, []

def _transition_cancelled_fetch(
self, ts: TaskState, *, stimulus_id: str
Expand Down Expand Up @@ -2078,17 +2093,29 @@ def _transition_cancelled_waiting(
See also
--------
_transition_cancelled_fetch
_transition_cancelled_long_running
_transition_resumed_fetch
_transition_resumed_waiting
"""
# None of the exit events of gather_dep or execute recommend a transition to
# waiting
assert not ts.done
if ts.previous in ("executing", "long-running"):
if ts.previous == "executing":
# Forget the task was cancelled to begin with
ts.state = ts.previous
ts.state = "executing"
ts.previous = None
return {}, []
elif ts.previous == "long-running":
# Forget the task was cancelled to begin with, and inform the scheduler
# in arrears that it has seceded.
# Note that, if the task seceded before it was cancelled, this will cause
# the message to be sent twice.
ts.state = "long-running"
ts.previous = None
smsg = LongRunningMsg(
key=ts.key, compute_duration=None, stimulus_id=stimulus_id
)
return {}, [smsg]
else:
assert ts.previous == "flight"
ts.state = "resumed"
Expand Down Expand Up @@ -2193,6 +2220,18 @@ def _transition_executing_long_running(
self._ensure_computing(),
)

def _transition_cancelled_long_running(
self, ts: TaskState, compute_duration: float, *, stimulus_id: str
) -> RecsInstrs:
"""This transition also serves resumed(fetch) -> long-running"""
assert ts.previous in ("executing", "long-running")
ts.previous = "long-running"
self.executing.discard(ts)
self.long_running.add(ts)

# Do not send LongRunningMsg
return self._ensure_computing()

def _transition_executing_memory(
self, ts: TaskState, value: object, *, stimulus_id: str
) -> RecsInstrs:
Expand Down Expand Up @@ -2299,13 +2338,15 @@ def _transition_released_forgotten(
] = {
("cancelled", "fetch"): _transition_cancelled_fetch,
("cancelled", "error"): _transition_cancelled_released,
("cancelled", "long-running"): _transition_cancelled_long_running,
("cancelled", "memory"): _transition_cancelled_released,
("cancelled", "missing"): _transition_cancelled_released,
("cancelled", "released"): _transition_cancelled_released,
("cancelled", "rescheduled"): _transition_cancelled_released,
("cancelled", "waiting"): _transition_cancelled_waiting,
("resumed", "memory"): _transition_resumed_memory,
("resumed", "error"): _transition_resumed_error,
("resumed", "long-running"): _transition_cancelled_long_running,
("resumed", "released"): _transition_resumed_released,
("resumed", "waiting"): _transition_resumed_waiting,
("resumed", "fetch"): _transition_resumed_fetch,
Expand Down Expand Up @@ -2834,10 +2875,9 @@ def _handle_gather_dep_failure(self, ev: GatherDepFailureEvent) -> RecsInstrs:
@_handle_event.register
def _handle_secede(self, ev: SecedeEvent) -> RecsInstrs:
ts = self.tasks.get(ev.key)
if ts and ts.state == "executing":
return {ts: ("long-running", ev.compute_duration)}, []
else:
if not ts:
return {}, []
return {ts: ("long-running", ev.compute_duration)}, []

@_handle_event.register
def _handle_steal_request(self, ev: StealRequestEvent) -> RecsInstrs:
Expand Down

0 comments on commit a6557e5

Please sign in to comment.