
Commit

Fix RescheduleMsg
crusaderky committed Jun 28, 2022
1 parent a8eb3b2 commit b3ffa87
Showing 7 changed files with 44 additions and 40 deletions.
10 changes: 6 additions & 4 deletions distributed/scheduler.py
@@ -6568,7 +6568,9 @@ async def get_story(self, keys_or_stimuli: Iterable[str]) -> list[tuple]:
 
     transition_story = story
 
-    def reschedule(self, key=None, worker=None):
+    def reschedule(
+        self, key: str, worker: str | None = None, *, stimulus_id: str
+    ) -> None:
         """Reschedule a task
 
         Things may have shifted and this task may now be better suited to run
@@ -6578,15 +6580,15 @@ def reschedule(self, key=None, worker=None):
             ts = self.tasks[key]
         except KeyError:
             logger.warning(
-                "Attempting to reschedule task {}, which was not "
-                "found on the scheduler. Aborting reschedule.".format(key)
+                f"Attempting to reschedule task {key}, which was not "
+                "found on the scheduler. Aborting reschedule."
             )
             return
         if ts.state != "processing":
             return
         if worker and ts.processing_on.address != worker:
             return
-        self.transitions({key: "released"}, f"reschedule-{time()}")
+        self.transitions({key: "released"}, stimulus_id=stimulus_id)
 
     #####################
     # Utility functions #
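
Note: with this change, Scheduler.reschedule no longer invents its own stimulus id; callers must supply one as a keyword-only argument. A minimal sketch of the new call pattern, assuming a distributed.Scheduler instance and a task key obtained elsewhere (the helper name is hypothetical):

    from time import time

    def reschedule_task(scheduler, key: str) -> None:
        # `scheduler` is assumed to be a running distributed.Scheduler and `key`
        # an existing task key. After this commit the caller threads through its
        # own stimulus_id (work stealing passes the id of the steal request) or
        # generates a fresh one, as the scheduler previously did internally.
        scheduler.reschedule(key, stimulus_id=f"reschedule-{time()}")
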
2 changes: 1 addition & 1 deletion distributed/stealing.py
@@ -341,7 +341,7 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None):
                     *_log_msg,
                 )
             )
-            self.scheduler.reschedule(key)
+            self.scheduler.reschedule(key, stimulus_id=stimulus_id)
         # Victim had already started execution
         elif state in _WORKER_STATE_REJECT:
             self.log(("already-computing", *_log_msg))

44 changes: 23 additions & 21 deletions distributed/tests/test_scheduler.py
@@ -1514,30 +1514,33 @@ async def test_log_tasks_during_restart(c, s, a, b):
     assert "exit" in str(s.events)
 
 
-@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
+@gen_cluster(
+    client=True,
+    nthreads=[("", 1)] * 2,
+    config={"distributed.scheduler.work-stealing": False},
+)
 async def test_reschedule(c, s, a, b):
-    await c.submit(slowinc, -1, delay=0.1)  # learn cost
-    x = c.map(slowinc, range(4), delay=0.1)
-
-    # add much more work onto worker a
-    futures = c.map(slowinc, range(10, 20), delay=0.1, workers=a.address)
+    xs = c.map(slowinc, range(100), key="x", delay=0.1)
+    while not a.state.tasks or not b.state.tasks:
+        await asyncio.sleep(0.01)
+    assert len(a.state.tasks) == len(b.state.tasks) == 50
 
-    while len(s.tasks) < len(x) + len(futures):
-        await asyncio.sleep(0.001)
+    ys = c.map(slowinc, range(100), key="y", delay=0.1, workers=[a.address])
+    while len(a.state.tasks) != 150:
+        await asyncio.sleep(0.01)
 
-    for future in x:
-        s.reschedule(key=future.key)
+    for x in xs:
+        if s.tasks[x.key].processing_on is s.workers[a.address]:
+            s.reschedule(x.key, stimulus_id="test")
 
     # Worker b gets more of the original tasks
-    await wait(x)
-    assert sum(future.key in b.data for future in x) >= 3
-    assert sum(future.key in a.data for future in x) <= 1
+    while len(a.state.tasks) == 150 or len(b.state.tasks) <= 50:
+        await asyncio.sleep(0.01)
 
 
-@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
-async def test_reschedule_warns(c, s, a, b):
-    with captured_logger(logging.getLogger("distributed.scheduler")) as sched:
-        s.reschedule(key="__this-key-does-not-exist__")
+@gen_cluster()
+async def test_reschedule_warns(s, a, b):
+    with captured_logger("distributed.scheduler") as sched:
+        s.reschedule(key="__this-key-does-not-exist__", stimulus_id="test")
 
     assert "not found on the scheduler" in sched.getvalue()
     assert "Aborting reschedule" in sched.getvalue()

@@ -3343,12 +3346,11 @@ async def test_worker_heartbeat_after_cancel(c, s, *workers):
 
 @gen_cluster(client=True, nthreads=[("", 1)] * 2)
 async def test_set_restrictions(c, s, a, b):
 
-    f = c.submit(inc, 1, workers=[b.address])
+    f = c.submit(inc, 1, key="f", workers=[b.address])
     await f
     s.set_restrictions(worker={f.key: a.address})
     assert s.tasks[f.key].worker_restrictions == {a.address}
-    s.reschedule(f)
     await b.close()
     await f
 
 
4 changes: 2 additions & 2 deletions distributed/tests/test_steal.py
@@ -1104,7 +1104,7 @@ async def test_steal_reschedule_reset_in_flight_occupancy(c, s, *workers):
 
     steal.move_task_request(victim_ts, wsA, wsB)
 
-    s.reschedule(victim_key)
+    s.reschedule(victim_key, stimulus_id="test")
     await c.gather(futs1)
 
     del futs1
@@ -1238,7 +1238,7 @@ async def test_reschedule_concurrent_requests_deadlock(c, s, *workers):
     steal.move_task_request(victim_ts, wsA, wsB)
 
     s.set_restrictions(worker={victim_key: [wsB.address]})
-    s.reschedule(victim_key)
+    s.reschedule(victim_key, stimulus_id="test")
     assert wsB == victim_ts.processing_on
     # move_task_request is not responsible for respecting worker restrictions
     steal.move_task_request(victim_ts, wsB, wsC)

13 changes: 9 additions & 4 deletions distributed/tests/test_worker.py
@@ -81,6 +81,7 @@
     ExecuteFailureEvent,
     ExecuteSuccessEvent,
     RemoveReplicasEvent,
+    RescheduleEvent,
     SerializedTask,
     StealRequestEvent,
 )
@@ -1180,20 +1181,24 @@ def some_name():
     assert result.startswith("some_name")
 
 
+@pytest.mark.slow
+@pytest.mark.parametrize("long_running", [False, True])
 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
-async def test_reschedule(c, s, a, b):
+async def test_reschedule(c, s, a, b, long_running):
     await s.extensions["stealing"].stop()
     a_address = a.address
 
     def f(x):
+        if long_running:
+            distributed.secede()
         sleep(0.1)
         if get_worker().address == a_address:
             raise Reschedule()
 
-    futures = c.map(f, range(4))
-    futures2 = c.map(slowinc, range(10), delay=0.1, workers=a.address)
+    futures = c.map(f, range(4), key=["x1", "x2", "x3", "x4"])
+    futures2 = c.map(slowinc, range(10), delay=0.1, key="clog", workers=[a.address])
     await wait(futures)
 
+    assert any(isinstance(ev, RescheduleEvent) for ev in a.state.stimulus_log)
     assert all(f.key in b.data for f in futures)
 
 
2 changes: 1 addition & 1 deletion distributed/worker.py
@@ -2218,7 +2218,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | None:
             )
 
         if isinstance(result["actual-exception"], Reschedule):
-            return RescheduleEvent(key=ts.key, stimulus_id=stimulus_id)
+            return RescheduleEvent(key=ts.key, stimulus_id=f"reschedule-{time()}")
 
         logger.warning(
             "Compute Failed\n"
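
For context, this branch of Worker.execute fires when user code raises the Reschedule exception; the worker turns it into a RescheduleEvent, now stamped with a fresh f"reschedule-{time()}" stimulus id instead of reusing the id of the original compute stimulus. A hedged sketch of a task that triggers it, mirroring the test above (the function name and the address check are illustrative):

    from distributed import Reschedule, get_worker

    def maybe_reschedule(x):
        # Raising Reschedule asks the scheduler to release this task and run it
        # elsewhere; the executing worker converts it into a RescheduleEvent.
        if get_worker().address == "tcp://10.0.0.1:34567":  # illustrative address
            raise Reschedule()
        return x + 1
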
9 changes: 2 additions & 7 deletions distributed/worker_state_machine.py
@@ -1693,9 +1693,7 @@ def _transition_waiting_constrained(
     def _transition_long_running_rescheduled(
         self, ts: TaskState, *, stimulus_id: str
     ) -> RecsInstrs:
-        recs: Recs = {ts: "released"}
-        smsg = RescheduleMsg(key=ts.key, stimulus_id=stimulus_id)
-        return recs, [smsg]
+        return {ts: "released"}, [RescheduleMsg(key=ts.key, stimulus_id=stimulus_id)]
 
     def _transition_executing_rescheduled(
         self, ts: TaskState, *, stimulus_id: str
@@ -1705,10 +1703,7 @@ def _transition_executing_rescheduled(
         self.executing.discard(ts)
 
         return merge_recs_instructions(
-            (
-                {ts: "released"},
-                [RescheduleMsg(key=ts.key, stimulus_id=stimulus_id)],
-            ),
+            ({ts: "released"}, [RescheduleMsg(key=ts.key, stimulus_id=stimulus_id)]),
             self._ensure_computing(),
         )
 
