Skip to content

Commit

Permalink
Remove RescheduleMsg
Browse files: browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jun 27, 2022
1 parent a8eb3b2 commit 611c85d
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 70 deletions.
28 changes: 12 additions & 16 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3031,7 +3031,6 @@ def __init__(
"release-worker-data": self.release_worker_data,
"add-keys": self.add_keys,
"long-running": self.handle_long_running,
"reschedule": self.reschedule,
"keep-alive": lambda *args, **kwargs: None,
"log-event": self.log_worker_event,
"worker-status-change": self.handle_worker_status_change,
Expand Down Expand Up @@ -6568,25 +6567,22 @@ async def get_story(self, keys_or_stimuli: Iterable[str]) -> list[tuple]:

transition_story = story

def reschedule(self, key=None, worker=None):
def reschedule(self, key: str, *, stimulus_id: str) -> None:
"""Reschedule a task
Things may have shifted and this task may now be better suited to run
elsewhere
elsewhere.
Note
----
This handler is used by work stealing exclusively.
When a task raises the Reschedule exception, it causes a transition to released
which, as soon as it reaches the scheduler, causes the task to immediately
transition back to waiting.
"""
try:
ts = self.tasks[key]
except KeyError:
logger.warning(
"Attempting to reschedule task {}, which was not "
"found on the scheduler. Aborting reschedule.".format(key)
)
return
if ts.state != "processing":
return
if worker and ts.processing_on.address != worker:
return
self.transitions({key: "released"}, f"reschedule-{time()}")
if self.validate:
assert self.tasks[key].state == "processing"
self.transitions({key: "released"}, stimulus_id)

#####################
# Utility functions #
Expand Down
2 changes: 1 addition & 1 deletion distributed/stealing.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None):
*_log_msg,
)
)
self.scheduler.reschedule(key)
self.scheduler.reschedule(key, stimulus_id=stimulus_id)
# Victim had already started execution
elif state in _WORKER_STATE_REJECT:
self.log(("already-computing", *_log_msg))
Expand Down
45 changes: 19 additions & 26 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1514,33 +1514,27 @@ async def test_log_tasks_during_restart(c, s, a, b):
assert "exit" in str(s.events)


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
@gen_cluster(
client=True,
nthreads=[("", 1)] * 2,
config={"distributed.scheduler.work-stealing": False},
)
async def test_reschedule(c, s, a, b):
await c.submit(slowinc, -1, delay=0.1) # learn cost
x = c.map(slowinc, range(4), delay=0.1)

# add much more work onto worker a
futures = c.map(slowinc, range(10, 20), delay=0.1, workers=a.address)

while len(s.tasks) < len(x) + len(futures):
await asyncio.sleep(0.001)

for future in x:
s.reschedule(key=future.key)

# Worker b gets more of the original tasks
await wait(x)
assert sum(future.key in b.data for future in x) >= 3
assert sum(future.key in a.data for future in x) <= 1
xs = c.map(slowinc, range(100), key="x", delay=0.1)
while not a.state.tasks or not b.state.tasks:
await asyncio.sleep(0.01)
assert len(a.state.tasks) == len(b.state.tasks) == 50

ys = c.map(slowinc, range(100), key="y", delay=0.1, workers=[a.address])
while len(a.state.tasks) != 150:
await asyncio.sleep(0.01)

@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
async def test_reschedule_warns(c, s, a, b):
with captured_logger(logging.getLogger("distributed.scheduler")) as sched:
s.reschedule(key="__this-key-does-not-exist__")
for x in xs:
if s.tasks[x.key].processing_on is s.workers[a.address]:
s.reschedule(x.key, stimulus_id="test")

assert "not found on the scheduler" in sched.getvalue()
assert "Aborting reschedule" in sched.getvalue()
while len(a.state.tasks) == 150 or len(b.state.tasks) <= 50:
await asyncio.sleep(0.01)


@gen_cluster(client=True)
Expand Down Expand Up @@ -3343,12 +3337,11 @@ async def test_worker_heartbeat_after_cancel(c, s, *workers):

@gen_cluster(client=True, nthreads=[("", 1)] * 2)
async def test_set_restrictions(c, s, a, b):

f = c.submit(inc, 1, workers=[b.address])
f = c.submit(inc, 1, key="f", workers=[b.address])
await f
s.set_restrictions(worker={f.key: a.address})
assert s.tasks[f.key].worker_restrictions == {a.address}
s.reschedule(f)
await b.close()
await f


Expand Down
4 changes: 2 additions & 2 deletions distributed/tests/test_steal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,7 +1104,7 @@ async def test_steal_reschedule_reset_in_flight_occupancy(c, s, *workers):

steal.move_task_request(victim_ts, wsA, wsB)

s.reschedule(victim_key)
s.reschedule(victim_key, stimulus_id="test")
await c.gather(futs1)

del futs1
Expand Down Expand Up @@ -1238,7 +1238,7 @@ async def test_reschedule_concurrent_requests_deadlock(c, s, *workers):
steal.move_task_request(victim_ts, wsA, wsB)

s.set_restrictions(worker={victim_key: [wsB.address]})
s.reschedule(victim_key)
s.reschedule(victim_key, stimulus_id="test")
assert wsB == victim_ts.processing_on
# move_task_request is not responsible for respecting worker restrictions
steal.move_task_request(victim_ts, wsB, wsC)
Expand Down
13 changes: 9 additions & 4 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
ExecuteFailureEvent,
ExecuteSuccessEvent,
RemoveReplicasEvent,
RescheduleEvent,
SerializedTask,
StealRequestEvent,
)
Expand Down Expand Up @@ -1180,20 +1181,24 @@ def some_name():
assert result.startswith("some_name")


@pytest.mark.slow
@pytest.mark.parametrize("long_running", [False, True])
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
async def test_reschedule(c, s, a, b):
async def test_reschedule(c, s, a, b, long_running):
await s.extensions["stealing"].stop()
a_address = a.address

def f(x):
if long_running:
distributed.secede()
sleep(0.1)
if get_worker().address == a_address:
raise Reschedule()

futures = c.map(f, range(4))
futures2 = c.map(slowinc, range(10), delay=0.1, workers=a.address)
futures = c.map(f, range(4), key=["x1", "x2", "x3", "x4"])
futures2 = c.map(slowinc, range(10), delay=0.1, key="clog", workers=[a.address])
await wait(futures)

assert any(isinstance(ev, RescheduleEvent) for ev in a.state.stimulus_log)
assert all(f.key in b.data for f in futures)


Expand Down
5 changes: 2 additions & 3 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
RefreshWhoHasEvent,
ReleaseWorkerDataMsg,
RescheduleEvent,
RescheduleMsg,
SerializedTask,
StateMachineEvent,
TaskState,
Expand Down Expand Up @@ -227,8 +226,8 @@ def test_sendmsg_to_dict():
def test_merge_recs_instructions():
x = TaskState("x")
y = TaskState("y")
instr1 = RescheduleMsg(key="foo", stimulus_id="test")
instr2 = RescheduleMsg(key="bar", stimulus_id="test")
instr1 = ReleaseWorkerDataMsg(key="foo", stimulus_id="test")
instr2 = ReleaseWorkerDataMsg(key="bar", stimulus_id="test")
assert merge_recs_instructions(
({x: "memory"}, [instr1]),
({y: "released"}, [instr2]),
Expand Down
2 changes: 1 addition & 1 deletion distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2218,7 +2218,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
)

if isinstance(result["actual-exception"], Reschedule):
return RescheduleEvent(key=ts.key, stimulus_id=stimulus_id)
return RescheduleEvent(key=ts.key, stimulus_id=f"reschedule-{time()}")

logger.warning(
"Compute Failed\n"
Expand Down
24 changes: 7 additions & 17 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,13 +425,11 @@ class ReleaseWorkerDataMsg(SendMessageToScheduler):
key: str


# Not to be confused with RescheduleEvent below or the distributed.Reschedule Exception
# FIXME temp hack - do not merge
@dataclass
class RescheduleMsg(SendMessageToScheduler):
op = "reschedule"

__slots__ = ("key",)
key: str
class DummyMsg(SendMessageToScheduler):
op = "dummy"
__slots__ = ()


@dataclass
Expand Down Expand Up @@ -770,7 +768,7 @@ class AlreadyCancelledEvent(StateMachineEvent):
key: str


# Not to be confused with RescheduleMsg above or the distributed.Reschedule Exception
# Not to be confused with the distributed.Reschedule Exception
@dataclass
class RescheduleEvent(StateMachineEvent):
__slots__ = ("key",)
Expand Down Expand Up @@ -1693,9 +1691,7 @@ def _transition_waiting_constrained(
def _transition_long_running_rescheduled(
self, ts: TaskState, *, stimulus_id: str
) -> RecsInstrs:
recs: Recs = {ts: "released"}
smsg = RescheduleMsg(key=ts.key, stimulus_id=stimulus_id)
return recs, [smsg]
return {ts: "released"}, [DummyMsg(stimulus_id=stimulus_id)]

def _transition_executing_rescheduled(
self, ts: TaskState, *, stimulus_id: str
Expand All @@ -1704,13 +1700,7 @@ def _transition_executing_rescheduled(
self.available_resources[resource] += quantity
self.executing.discard(ts)

return merge_recs_instructions(
(
{ts: "released"},
[RescheduleMsg(key=ts.key, stimulus_id=stimulus_id)],
),
self._ensure_computing(),
)
return {ts: "released"}, [DummyMsg(stimulus_id=stimulus_id)]

def _transition_waiting_ready(
self, ts: TaskState, *, stimulus_id: str
Expand Down

0 comments on commit 611c85d

Please sign in to comment.