cancelled/resumed->long-running transitions #6916

Merged 1 commit on Aug 31, 2022
13 changes: 7 additions & 6 deletions distributed/scheduler.py
@@ -4854,7 +4854,7 @@ def release_worker_data(self, key: str, worker: str, stimulus_id: str) -> None:
self.transitions({key: "released"}, stimulus_id)

     def handle_long_running(
-        self, key: str, worker: str, compute_duration: float, stimulus_id: str
+        self, key: str, worker: str, compute_duration: float | None, stimulus_id: str
     ) -> None:
"""A task has seceded from the thread pool

@@ -4874,11 +4874,12 @@ def handle_long_running(
logger.debug("Received long-running signal from duplicate task. Ignoring.")
return

-        old_duration = ts.prefix.duration_average
-        if old_duration < 0:
-            ts.prefix.duration_average = compute_duration
-        else:
-            ts.prefix.duration_average = (old_duration + compute_duration) / 2
+        if compute_duration is not None:
+            old_duration = ts.prefix.duration_average
+            if old_duration < 0:
+                ts.prefix.duration_average = compute_duration
+            else:
+                ts.prefix.duration_average = (old_duration + compute_duration) / 2

occ = ws.processing[ts]
ws.occupancy -= occ
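
The duration-averaging rule above is easy to misread in diff form. Here is the same logic as a minimal standalone sketch (the helper name is illustrative, not part of the codebase; as in the handler above, a negative stored average is the sentinel for "no estimate yet", and a None duration, the deferred-secede case introduced by this PR, leaves the estimate untouched):

def updated_duration_average(old: float, new: float | None) -> float:
    """Blend a newly reported compute duration into the stored average."""
    if new is None:  # a LongRunningMsg sent in arrears carries no duration
        return old
    if old < 0:  # sentinel: no previous estimate exists yet
        return new
    return (old + new) / 2  # equal-weight blend of old estimate and new sample
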
115 changes: 113 additions & 2 deletions distributed/tests/test_cancelled_state.py
@@ -30,7 +30,9 @@
GatherDepFailureEvent,
GatherDepNetworkFailureEvent,
GatherDepSuccessEvent,
LongRunningMsg,
RescheduleEvent,
SecedeEvent,
TaskFinishedMsg,
UpdateDataEvent,
)
@@ -640,7 +642,12 @@ def test_workerstate_executing_to_executing(ws_with_running_task):
FreeKeysEvent(keys=["x"], stimulus_id="s1"),
ComputeTaskEvent.dummy("x", resource_restrictions={"R": 1}, stimulus_id="s2"),
)
-    assert not instructions
+    if prev_state == "executing":
+        assert not instructions
+    else:
+        assert instructions == [
+            LongRunningMsg(key="x", compute_duration=None, stimulus_id="s2")
+        ]
assert ws.tasks["x"] is ts
assert ts.state == prev_state

@@ -821,7 +828,12 @@ def test_workerstate_resumed_fetch_to_executing(ws_with_running_task):
FreeKeysEvent(keys=["y", "x"], stimulus_id="s3"),
ComputeTaskEvent.dummy("x", resource_restrictions={"R": 1}, stimulus_id="s4"),
)
-    assert not instructions
+    if prev_state == "executing":
+        assert not instructions
+    else:
+        assert instructions == [
+            LongRunningMsg(key="x", compute_duration=None, stimulus_id="s4")
+        ]
assert ws.tasks["x"].state == prev_state


@@ -946,3 +958,102 @@ def test_cancel_with_dependencies_in_memory(ws, release_dep, done_ev_cls):
ws.handle_stimulus(done_ev_cls.dummy("y", stimulus_id="s5"))
assert "y" not in ws.tasks
assert ws.tasks["x"].state == "memory"


@pytest.mark.parametrize("resume_to_fetch", [False, True])
@pytest.mark.parametrize("resume_to_executing", [False, True])
@pytest.mark.parametrize(
"done_ev_cls", [ExecuteSuccessEvent, ExecuteFailureEvent, RescheduleEvent]
)
def test_secede_cancelled_or_resumed_workerstate(
ws, resume_to_fetch, resume_to_executing, done_ev_cls
):
"""Test what happens when a cancelled or resumed(fetch) task calls secede().
See also test_secede_cancelled_or_resumed_scheduler
"""
ws2 = "127.0.0.1:2"
ws.handle_stimulus(
ComputeTaskEvent.dummy("x", stimulus_id="s1"),
FreeKeysEvent(keys=["x"], stimulus_id="s2"),
)
if resume_to_fetch:
ws.handle_stimulus(
ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s3"),
)
ts = ws.tasks["x"]
assert ts.previous == "executing"
assert ts in ws.executing
assert ts not in ws.long_running

instructions = ws.handle_stimulus(
SecedeEvent(key="x", compute_duration=1, stimulus_id="s4")
)
assert not instructions # Do not send RescheduleMsg
assert ts.previous == "long-running"
assert ts not in ws.executing
assert ts in ws.long_running

if resume_to_executing:
instructions = ws.handle_stimulus(
FreeKeysEvent(keys=["y"], stimulus_id="s5"),
ComputeTaskEvent.dummy("x", stimulus_id="s6"),
)
# Inform the scheduler of the SecedeEvent that happened in the past
assert instructions == [
LongRunningMsg(key="x", compute_duration=None, stimulus_id="s6")
]
assert ts.state == "long-running"
assert ts not in ws.executing
assert ts in ws.long_running

ws.handle_stimulus(done_ev_cls.dummy(key="x", stimulus_id="s7"))
assert ts not in ws.executing
assert ts not in ws.long_running


@gen_cluster(client=True, nthreads=[("", 1)], timeout=2)
async def test_secede_cancelled_or_resumed_scheduler(c, s, a):
Collaborator Author commented:

Note that this second test does not test the resumed(fetch) use case. However, the first test above demonstrates that the cancelled and resumed(fetch) use cases are indistinguishable from the scheduler's side.

"""Same as test_secede_cancelled_or_resumed_workerstate, but testing the interaction
with the scheduler
"""
ws = s.workers[a.address]
ev1 = Event()
ev2 = Event()
ev3 = Event()
ev4 = Event()

def f(ev1, ev2, ev3, ev4):
ev1.set()
ev2.wait()
distributed.secede()
ev3.set()
ev4.wait()
return 123

x = c.submit(f, ev1, ev2, ev3, ev4, key="x")
await ev1.wait()
ts = a.state.tasks["x"]
assert ts.state == "executing"
assert sum(ws.processing.values()) > 0

x.release()
await wait_for_state("x", "cancelled", a)
assert not ws.processing

await ev2.set()
await ev3.wait()
assert ts.previous == "long-running"
assert not ws.processing

x = c.submit(inc, 1, key="x")
await wait_for_state("x", "long-running", a)

# Test that the scheduler receives a delayed {op: long-running}
assert ws.processing
while sum(ws.processing.values()):
await asyncio.sleep(0.1)
assert ws.processing

await ev4.set()
assert await x == 123
Collaborator commented:

Wait, so the expected, correct behavior is that you release a future, submit a new future with the same key, and get back the old (cancelled) future's result instead of the new one? That seems pretty wrong to me.

I'm aware that this could happen even for normal tasks, not just long-running, and it's just a consequence of not cancelling the thread, and keeping the TaskState around until the thread finishes. But from an API and user perspective, that seems wrong. I didn't think keys needed to be unique over the lifetime of the cluster, just that they needed to be unique among all the currently-active keys (and once a client saw a key as released, then it could safely consider it inactive).

Collaborator Author replied:

Yep, but this is how it works. I spent several weeks trying and failing to make it become sensible: #6844

This is a pretty rare use case: a user submits a task with a manually-defined key; then, before that task has had time to finish, they submit a different task with the same key.
Honestly, I feel the blame should sit entirely on the user here, and figuring out what went wrong should be pretty straightforward. It also should not really happen except when prototyping from a notebook, unless there are key collisions, which will cause all sorts of weird behaviour anyway.

assert not ws.processing
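
For illustration, a minimal client-side sketch of the scenario discussed in the thread above; slow and its 10-second sleep are hypothetical stand-ins for a task that is still running when its key is released:

from time import sleep

from distributed import Client

def slow(x):
    sleep(10)  # still running in the worker thread when "x" is released
    return x

def inc(x):
    return x + 1

c = Client()
x = c.submit(slow, 1, key="x")  # starts executing in a worker thread
x.release()                     # release "x" while the thread keeps running
x = c.submit(inc, 2, key="x")   # reuse the same manually-defined key
# The worker does not kill the running thread, so this may print slow(1) == 1
# rather than inc(2) == 3.
print(x.result())
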
10 changes: 9 additions & 1 deletion distributed/tests/test_resources.py
@@ -17,6 +17,7 @@
ExecuteFailureEvent,
ExecuteSuccessEvent,
FreeKeysEvent,
LongRunningMsg,
RescheduleEvent,
TaskFinishedMsg,
)
@@ -565,14 +566,21 @@ def test_resumed_with_different_resources(ws_with_running_task, done_ev_cls):
"""
ws = ws_with_running_task
assert ws.available_resources == {"R": 0}
ts = ws.tasks["x"]
prev_state = ts.state

ws.handle_stimulus(FreeKeysEvent(keys=["x"], stimulus_id="s1"))
assert ws.available_resources == {"R": 0}

instructions = ws.handle_stimulus(
ComputeTaskEvent.dummy("x", stimulus_id="s2", resource_restrictions={"R": 0.4})
)
-    assert not instructions
+    if prev_state == "long-running":
+        assert instructions == [
+            LongRunningMsg(key="x", compute_duration=None, stimulus_id="s2")
+        ]
+    else:
+        assert not instructions
assert ws.available_resources == {"R": 0}

ws.handle_stimulus(done_ev_cls.dummy(key="x", stimulus_id="s3"))
84 changes: 72 additions & 12 deletions distributed/worker_state_machine.py
@@ -517,7 +517,7 @@ class LongRunningMsg(SendMessageToScheduler):

__slots__ = ("key", "compute_duration")
key: str
-    compute_duration: float
+    compute_duration: float | None


@dataclass
@@ -2077,24 +2077,39 @@ def _transition_resumed_waiting(
See also
--------
_transition_cancelled_fetch
_transition_cancelled_or_resumed_long_running
_transition_cancelled_waiting
_transition_resumed_fetch
"""
# None of the exit events of execute or gather_dep recommend a transition to
# waiting
assert not ts.done
-        if ts.previous in ("executing", "long-running"):
+        if ts.previous == "executing":
            assert ts.next == "fetch"
            # We're back where we started. We should forget about the entire
            # cancellation attempt
-            ts.state = ts.previous
+            ts.state = "executing"
            ts.next = None
            ts.previous = None
-        elif self.validate:
+            return {}, []
+
+        elif ts.previous == "long-running":
+            assert ts.next == "fetch"
+            # Same as executing, and in addition send the LongRunningMsg in arrears
+            # Note that, if the task seceded before it was cancelled, this will cause
+            # the message to be sent twice.
+            ts.state = "long-running"
+            ts.next = None
+            ts.previous = None
+            smsg = LongRunningMsg(
+                key=ts.key, compute_duration=None, stimulus_id=stimulus_id
+            )
+            return {}, [smsg]
+
+        else:
            assert ts.previous == "flight"
            assert ts.next == "waiting"
-
-        return {}, []
+            return {}, []

def _transition_cancelled_fetch(
self, ts: TaskState, *, stimulus_id: str
@@ -2131,17 +2146,29 @@ def _transition_cancelled_waiting(
See also
--------
_transition_cancelled_fetch
_transition_cancelled_or_resumed_long_running
_transition_resumed_fetch
_transition_resumed_waiting
"""
# None of the exit events of gather_dep or execute recommend a transition to
# waiting
assert not ts.done
-        if ts.previous in ("executing", "long-running"):
+        if ts.previous == "executing":
            # Forget the task was cancelled to begin with
-            ts.state = ts.previous
+            ts.state = "executing"
            ts.previous = None
            return {}, []
+        elif ts.previous == "long-running":
+            # Forget the task was cancelled to begin with, and inform the scheduler
+            # in arrears that it has seceded.
+            # Note that, if the task seceded before it was cancelled, this will cause
+            # the message to be sent twice.
+            ts.state = "long-running"
+            ts.previous = None
+            smsg = LongRunningMsg(
+                key=ts.key, compute_duration=None, stimulus_id=stimulus_id
+            )
+            return {}, [smsg]
else:
assert ts.previous == "flight"
ts.state = "resumed"
@@ -2234,6 +2261,11 @@ def _transition_flight_released(
def _transition_executing_long_running(
self, ts: TaskState, compute_duration: float, *, stimulus_id: str
) -> RecsInstrs:
"""
See also
--------
_transition_cancelled_or_resumed_long_running
"""
ts.state = "long-running"
self.executing.discard(ts)
self.long_running.add(ts)
@@ -2246,6 +2278,34 @@
self._ensure_computing(),
)

def _transition_cancelled_or_resumed_long_running(
self, ts: TaskState, compute_duration: float, *, stimulus_id: str
) -> RecsInstrs:
"""Handles transitions:

- cancelled(executing) -> long-running
- cancelled(long-running) -> long-running (user called secede() twice)
- resumed(executing->fetch) -> long-running
- resumed(long-running->fetch) -> long-running (user called secede() twice)

Unlike in the executing->long_running transition, do not send LongRunningMsg.
From the scheduler's perspective, this task no longer exists (cancelled) or is
in memory on another worker (resumed). So it shouldn't hear about it.
Instead, we're going to send the LongRunningMsg when and if the task
transitions back to waiting.

See also
--------
_transition_executing_long_running
_transition_cancelled_waiting
_transition_resumed_waiting
"""
assert ts.previous in ("executing", "long-running")
ts.previous = "long-running"
self.executing.discard(ts)
self.long_running.add(ts)
return self._ensure_computing()

def _transition_executing_memory(
self, ts: TaskState, value: object, *, stimulus_id: str
) -> RecsInstrs:
@@ -2352,15 +2412,16 @@ def _transition_released_forgotten(
] = {
("cancelled", "error"): _transition_cancelled_released,
("cancelled", "fetch"): _transition_cancelled_fetch,
("cancelled", "long-running"): _transition_cancelled_or_resumed_long_running,
("cancelled", "memory"): _transition_cancelled_released,
("cancelled", "missing"): _transition_cancelled_released,
("cancelled", "released"): _transition_cancelled_released,
("cancelled", "rescheduled"): _transition_cancelled_released,
("cancelled", "waiting"): _transition_cancelled_waiting,
("resumed", "error"): _transition_resumed_error,
("resumed", "fetch"): _transition_resumed_fetch,
("resumed", "long-running"): _transition_cancelled_or_resumed_long_running,
("resumed", "memory"): _transition_resumed_memory,
("resumed", "missing"): _transition_resumed_missing,
("resumed", "released"): _transition_resumed_released,
("resumed", "rescheduled"): _transition_resumed_rescheduled,
("resumed", "waiting"): _transition_resumed_waiting,
@@ -2898,10 +2959,9 @@ def _handle_gather_dep_failure(self, ev: GatherDepFailureEvent) -> RecsInstrs:
@_handle_event.register
def _handle_secede(self, ev: SecedeEvent) -> RecsInstrs:
ts = self.tasks.get(ev.key)
-        if ts and ts.state == "executing":
-            return {ts: ("long-running", ev.compute_duration)}, []
-        else:
+        if not ts:
            return {}, []
+        return {ts: ("long-running", ev.compute_duration)}, []

@_handle_event.register
def _handle_steal_request(self, ev: StealRequestEvent) -> RecsInstrs:
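
As a closing note for readers unfamiliar with the worker state machine: the _transitions_table shown above dispatches on (current state, recommended state) pairs, which is how the two new entries route a secede recommendation differently depending on whether the task is live, cancelled, or resumed. A self-contained sketch of that dispatch pattern (illustrative names only, not the real implementation):

from typing import Callable

# Simplified: a transition returns a description of the message it would send.
TransitionFn = Callable[[str], str]

def executing_to_long_running(stimulus_id: str) -> str:
    # The scheduler still tracks this task: notify it immediately.
    return f"LongRunningMsg sent now ({stimulus_id})"

def cancelled_or_resumed_to_long_running(stimulus_id: str) -> str:
    # The scheduler believes this task is gone: record the secede locally and
    # notify only if the task later transitions back to waiting.
    return f"LongRunningMsg deferred ({stimulus_id})"

TRANSITIONS: dict[tuple[str, str], TransitionFn] = {
    ("executing", "long-running"): executing_to_long_running,
    ("cancelled", "long-running"): cancelled_or_resumed_to_long_running,
    ("resumed", "long-running"): cancelled_or_resumed_to_long_running,
}

def transition(current: str, recommended: str, stimulus_id: str) -> str:
    return TRANSITIONS[(current, recommended)](stimulus_id)

print(transition("cancelled", "long-running", "s1"))  # deferred variant
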