Skip to content

Commit

Permalink
Remove RescheduleMsg
Browse files: browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jun 27, 2022
1 parent 55a4fa8 commit 2e4522e
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 58 deletions.
27 changes: 10 additions & 17 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3031,7 +3031,6 @@ def __init__(
"release-worker-data": self.release_worker_data,
"add-keys": self.add_keys,
"long-running": self.handle_long_running,
"reschedule": self.reschedule,
"keep-alive": lambda *args, **kwargs: None,
"log-event": self.log_worker_event,
"worker-status-change": self.handle_worker_status_change,
Expand Down Expand Up @@ -6568,26 +6567,20 @@ async def get_story(self, keys_or_stimuli: Iterable[str]) -> list[tuple]:

transition_story = story

def reschedule(
self, key: str, worker: str | None = None, *, stimulus_id: str
) -> None:
def reschedule(self, key: str, *, stimulus_id: str) -> None:
"""Reschedule a task
Things may have shifted and this task may now be better suited to run
elsewhere
elsewhere.
Note
----
This handler is used by work stealing exclusively.
When a task raises the Reschedule exception, it causes a transition to released
which, as soon as it reaches the scheduler, causes the task to immediately
transition back to waiting.
"""
try:
ts = self.tasks[key]
except KeyError:
logger.warning(
f"Attempting to reschedule task {key}, which was not "
"found on the scheduler. Aborting reschedule."
)
return
if ts.state != "processing":
return
if worker and ts.processing_on.address != worker:
return
assert self.tasks[key].state == "processing"
self.transitions({key: "released"}, stimulus_id)

#####################
Expand Down
45 changes: 20 additions & 25 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1514,33 +1514,29 @@ async def test_log_tasks_during_restart(c, s, a, b):
assert "exit" in str(s.events)


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
@gen_cluster(
client=True,
nthreads=[("", 1)] * 2,
config={"distributed.scheduler.work-stealing": False},
)
async def test_reschedule(c, s, a, b):
await c.submit(slowinc, -1, delay=0.1) # learn cost
x = c.map(slowinc, range(4), delay=0.1)

# add much more work onto worker a
futures = c.map(slowinc, range(10, 20), delay=0.1, workers=a.address)

while len(s.tasks) < len(x) + len(futures):
await asyncio.sleep(0.001)

for future in x:
s.reschedule(future.key, stimulus_id="test")
xs = c.map(slowinc, range(100), key="x", delay=0.1)
while not a.state.tasks:
await asyncio.sleep(0.01)
assert len(a.state.tasks) == 50

# Worker b gets more of the original tasks
await wait(x)
assert sum(future.key in b.data for future in x) >= 3
assert sum(future.key in a.data for future in x) <= 1
ys = c.map(slowinc, range(100), key="y", delay=0.1, workers=[a.address])
while len(a.state.tasks) != 150:
await asyncio.sleep(0.01)

assert len(b.state.tasks) == 50

@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
async def test_reschedule_warns(c, s, a, b):
with captured_logger(logging.getLogger("distributed.scheduler")) as sched:
s.reschedule("__this-key-does-not-exist__", stimulus_id="test")
for x in xs:
if x.key in a.state.tasks and s.tasks[x.key].state == "processing":
s.reschedule(x.key, stimulus_id="test")

assert "not found on the scheduler" in sched.getvalue()
assert "Aborting reschedule" in sched.getvalue()
while len(b.state.tasks) <= 50:
await asyncio.sleep(0.01)


@gen_cluster(client=True)
Expand Down Expand Up @@ -3343,12 +3339,11 @@ async def test_worker_heartbeat_after_cancel(c, s, *workers):

@gen_cluster(client=True, nthreads=[("", 1)] * 2)
async def test_set_restrictions(c, s, a, b):

f = c.submit(inc, 1, workers=[b.address])
f = c.submit(inc, 1, key="f", workers=[b.address])
await f
s.set_restrictions(worker={f.key: a.address})
assert s.tasks[f.key].worker_restrictions == {a.address}
s.reschedule(f, stimulus_id="test")
await b.close()
await f


Expand Down
13 changes: 9 additions & 4 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
ExecuteFailureEvent,
ExecuteSuccessEvent,
RemoveReplicasEvent,
RescheduleEvent,
SerializedTask,
StealRequestEvent,
)
Expand Down Expand Up @@ -1180,20 +1181,24 @@ def some_name():
assert result.startswith("some_name")


@pytest.mark.slow
@pytest.mark.parametrize("long_running", [False, True])
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
async def test_reschedule(c, s, a, b):
async def test_reschedule(c, s, a, b, long_running):
await s.extensions["stealing"].stop()
a_address = a.address

def f(x):
if long_running:
distributed.secede()
sleep(0.1)
if get_worker().address == a_address:
raise Reschedule()

futures = c.map(f, range(4))
futures2 = c.map(slowinc, range(10), delay=0.1, workers=a.address)
futures = c.map(f, range(4), key=["x1", "x2", "x3", "x4"])
futures2 = c.map(slowinc, range(10), delay=0.1, key="clog", workers=[a.address])
await wait(futures)

assert any(isinstance(ev, RescheduleEvent) for ev in a.state.stimulus_log)
assert all(f.key in b.data for f in futures)


Expand Down
5 changes: 2 additions & 3 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
RefreshWhoHasEvent,
ReleaseWorkerDataMsg,
RescheduleEvent,
RescheduleMsg,
SerializedTask,
StateMachineEvent,
TaskState,
Expand Down Expand Up @@ -227,8 +226,8 @@ def test_sendmsg_to_dict():
def test_merge_recs_instructions():
x = TaskState("x")
y = TaskState("y")
instr1 = RescheduleMsg(key="foo", stimulus_id="test")
instr2 = RescheduleMsg(key="bar", stimulus_id="test")
instr1 = ReleaseWorkerDataMsg(key="foo", stimulus_id="test")
instr2 = ReleaseWorkerDataMsg(key="bar", stimulus_id="test")
assert merge_recs_instructions(
({x: "memory"}, [instr1]),
({y: "released"}, [instr2]),
Expand Down
16 changes: 7 additions & 9 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,13 +425,11 @@ class ReleaseWorkerDataMsg(SendMessageToScheduler):
key: str


# Not to be confused with RescheduleEvent below or the distributed.Reschedule Exception
# FIXME temp hack - do not merge
@dataclass
class RescheduleMsg(SendMessageToScheduler):
op = "reschedule"

__slots__ = ("key",)
key: str
class DummyMsg(SendMessageToScheduler):
op = "dummy"
__slots__ = ()


@dataclass
Expand Down Expand Up @@ -770,7 +768,7 @@ class AlreadyCancelledEvent(StateMachineEvent):
key: str


# Not to be confused with RescheduleMsg above or the distributed.Reschedule Exception
# Not to be confused with the distributed.Reschedule Exception
@dataclass
class RescheduleEvent(StateMachineEvent):
__slots__ = ("key",)
Expand Down Expand Up @@ -1693,7 +1691,7 @@ def _transition_waiting_constrained(
def _transition_long_running_rescheduled(
self, ts: TaskState, *, stimulus_id: str
) -> RecsInstrs:
return {ts: "released"}, [RescheduleMsg(key=ts.key, stimulus_id=stimulus_id)]
return {ts: "released"}, [DummyMsg(stimulus_id=stimulus_id)]

def _transition_executing_rescheduled(
self, ts: TaskState, *, stimulus_id: str
Expand All @@ -1702,7 +1700,7 @@ def _transition_executing_rescheduled(
self.available_resources[resource] += quantity
self.executing.discard(ts)

return {ts: "released"}, [RescheduleMsg(key=ts.key, stimulus_id=stimulus_id)]
return {ts: "released"}, [DummyMsg(stimulus_id=stimulus_id)]

def _transition_waiting_ready(
self, ts: TaskState, *, stimulus_id: str
Expand Down

0 comments on commit 2e4522e

Please sign in to comment.