Scheduler.reschedule() works only by accident #6339

Merged 6 commits on Jun 30, 2022
28 changes: 15 additions & 13 deletions distributed/scheduler.py
@@ -2223,21 +2223,19 @@ def transition_waiting_released(self, key, stimulus_id):
                 pdb.set_trace()
             raise
 
-    def transition_processing_released(self, key, stimulus_id):
+    def transition_processing_released(self, key: str, stimulus_id: str):
         try:
-            ts: TaskState = self.tasks[key]
-            dts: TaskState
-            recommendations: dict = {}
-            client_msgs: dict = {}
-            worker_msgs: dict = {}
+            ts = self.tasks[key]
+            recommendations = {}
+            worker_msgs = {}
 
             if self.validate:
                 assert ts.processing_on
                 assert not ts.who_has
                 assert not ts.waiting_on
-                assert self.tasks[key].state == "processing"
+                assert ts.state == "processing"
 
-            w: str = _remove_from_processing(self, ts)
+            w = _remove_from_processing(self, ts)
             if w:
                 worker_msgs[w] = [
                     {
@@ -2265,7 +2263,7 @@ def transition_processing_released(self, key, stimulus_id):
             if self.validate:
                 assert not ts.processing_on
 
-            return recommendations, client_msgs, worker_msgs
+            return recommendations, {}, worker_msgs
         except Exception as e:
             logger.exception(e)
             if LOG_PDB:
@@ -6606,7 +6604,9 @@ async def get_story(self, keys_or_stimuli: Iterable[str]) -> list[tuple]:
 
     transition_story = story
 
-    def reschedule(self, key=None, worker=None):
+    def reschedule(
Review comment (Member): Out-of-scope comment: Is there a better name that highlights that reschedule does not actually reschedule, i.e., it does not schedule the task somewhere else; it merely cancels the previous scheduling decision? For example, deschedule might be better.

Reply (Collaborator, author): Well, no. A transition to released will automatically kick the task back to waiting. The functional description of the method is accurate; the "released" bit is an implementation detail. I'm adding a comment to explain.

Reply (Member): Makes sense from that perspective. I was thinking that the function doesn't include the rescheduling bit, but in the end you're right: the releasing/descheduling automatically achieves the aim of rescheduling.

+        self, key: str, worker: str | None = None, *, stimulus_id: str
+    ) -> None:
         """Reschedule a task
         Things may have shifted and this task may now be better suited to run
@@ -6616,15 +6616,17 @@ def reschedule(self, key=None, worker=None):
             ts = self.tasks[key]
         except KeyError:
             logger.warning(
-                "Attempting to reschedule task {}, which was not "
-                "found on the scheduler. Aborting reschedule.".format(key)
+                f"Attempting to reschedule task {key}, which was not "
+                "found on the scheduler. Aborting reschedule."
             )
             return
         if ts.state != "processing":
             return
         if worker and ts.processing_on.address != worker:
             return
-        self.transitions({key: "released"}, f"reschedule-{time()}")
+        # transition_processing_released will immediately suggest an additional
+        # transition to waiting if the task has any waiters or clients holding a future.
+        self.transitions({key: "released"}, stimulus_id=stimulus_id)
 
     #####################
     # Utility functions #
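To make the new calling convention concrete, here is a minimal sketch of a caller after this change. It is a hypothetical test in the style of those below, not code from this PR; the test name, key, and stimulus label are made up.

import asyncio
from time import time

from distributed.utils_test import gen_cluster, slowinc


@gen_cluster(client=True)
async def test_manual_reschedule(c, s, a, b):  # hypothetical example
    x = c.submit(slowinc, 1, key="x", delay=0.5)
    while "x" not in s.tasks or s.tasks["x"].state != "processing":
        await asyncio.sleep(0.01)

    # stimulus_id is now a required keyword-only argument: the caller states
    # why it is rescheduling instead of the scheduler inventing a label.
    # reschedule() transitions the task to released; because the client still
    # holds a future, the scheduler immediately kicks it back to waiting and
    # assigns it afresh.
    s.reschedule("x", stimulus_id=f"manual-reschedule-{time()}")
    assert await x == 2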
2 changes: 1 addition & 1 deletion distributed/stealing.py
@@ -341,7 +341,7 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None):
                         *_log_msg,
                     )
                 )
-                self.scheduler.reschedule(key)
+                self.scheduler.reschedule(key, stimulus_id=stimulus_id)
             # Victim had already started execution
             elif state in _WORKER_STATE_REJECT:
                 self.log(("already-computing", *_log_msg))
14 changes: 8 additions & 6 deletions distributed/tests/test_client.py
@@ -4299,11 +4299,12 @@ async def test_retire_many_workers(c, s, *workers):
 @gen_cluster(
     client=True,
     nthreads=[("127.0.0.1", 3)] * 2,
-    config={"distributed.scheduler.default-task-durations": {"f": "10ms"}},
+    config={
+        "distributed.scheduler.work-stealing": False,
+        "distributed.scheduler.default-task-durations": {"f": "10ms"},
+    },
 )
 async def test_weight_occupancy_against_data_movement(c, s, a, b):
-    await s.extensions["stealing"].stop()
-
     def f(x, y=0, z=0):
         sleep(0.01)
         return x
@@ -4322,11 +4323,12 @@ def f(x, y=0, z=0):
 @gen_cluster(
     client=True,
     nthreads=[("127.0.0.1", 1), ("127.0.0.1", 10)],
-    config={"distributed.scheduler.default-task-durations": {"f": "10ms"}},
+    config={
+        "distributed.scheduler.work-stealing": False,
+        "distributed.scheduler.default-task-durations": {"f": "10ms"},
+    },
 )
 async def test_distribute_tasks_by_nthreads(c, s, a, b):
-    await s.extensions["stealing"].stop()
-
     def f(x, y=0):
         sleep(0.01)
         return x
123 changes: 123 additions & 0 deletions distributed/tests/test_reschedule.py
@@ -0,0 +1,123 @@
"""Tests for tasks raising the Reschedule exception and Scheduler.reschedule().

Note that this functionality is also used by work stealing;
see test_steal.py for additional tests.
"""
from __future__ import annotations

import asyncio
from time import sleep

import pytest

from distributed import Event, Reschedule, get_worker, secede, wait
from distributed.utils_test import captured_logger, gen_cluster, slowinc
from distributed.worker_state_machine import (
    ComputeTaskEvent,
    FreeKeysEvent,
    RescheduleEvent,
    SecedeEvent,
)


@gen_cluster(
    client=True,
    nthreads=[("", 1)] * 2,
    config={"distributed.scheduler.work-stealing": False},
)
async def test_scheduler_reschedule(c, s, a, b):
    xs = c.map(slowinc, range(100), key="x", delay=0.1)
    while not a.state.tasks or not b.state.tasks:
        await asyncio.sleep(0.01)
    assert len(a.state.tasks) == len(b.state.tasks) == 50

    ys = c.map(slowinc, range(100), key="y", delay=0.1, workers=[a.address])
    while len(a.state.tasks) != 150:
        await asyncio.sleep(0.01)

    # Reschedule the 50 xs that are processing on a
    for x in xs:
        if s.tasks[x.key].processing_on is s.workers[a.address]:
            s.reschedule(x.key, stimulus_id="test")

    # Wait for at least some of the 50 xs that had been scheduled on a to move to b.
    # This happens because you have 100 ys processing on a and 50 xs processing on b,
    # so the scheduler will prefer b for the rescheduled tasks to obtain more equal
    # balancing.
    while len(a.state.tasks) == 150 or len(b.state.tasks) <= 50:
        await asyncio.sleep(0.01)


@gen_cluster()
async def test_scheduler_reschedule_warns(s, a, b):
    with captured_logger("distributed.scheduler") as sched:
        s.reschedule(key="__this-key-does-not-exist__", stimulus_id="test")

    assert "not found on the scheduler" in sched.getvalue()
    assert "Aborting reschedule" in sched.getvalue()


@pytest.mark.parametrize("long_running", [False, True])
@gen_cluster(
    client=True,
    nthreads=[("", 1)] * 2,
    config={"distributed.scheduler.work-stealing": False},
)
async def test_raise_reschedule(c, s, a, b, long_running):
    """A task raises Reschedule()"""
    a_address = a.address

    def f(x):
        if long_running:
            secede()
        sleep(0.1)
        if get_worker().address == a_address:
            raise Reschedule()

    futures = c.map(f, range(4), key=["x1", "x2", "x3", "x4"])
    futures2 = c.map(slowinc, range(10), delay=0.1, key="clog", workers=[a.address])
    await wait(futures)
    assert any(isinstance(ev, RescheduleEvent) for ev in a.state.stimulus_log)
    assert all(f.key in b.data for f in futures)


Review comment (Collaborator, author): The two tests below are new.

@pytest.mark.parametrize("long_running", [False, True])
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_cancelled_reschedule(c, s, a, long_running):
    """A task raises Reschedule(), but the future was released by the client"""
    ev1 = Event()
    ev2 = Event()

    def f(ev1, ev2):
        if long_running:
            secede()
        ev1.set()
        ev2.wait()
        raise Reschedule()

    x = c.submit(f, ev1, ev2, key="x")
    await ev1.wait()
    x.release()
    while "x" in s.tasks:
        await asyncio.sleep(0.01)

    await ev2.set()
    while "x" in a.state.tasks:
        await asyncio.sleep(0.01)


@pytest.mark.parametrize("long_running", [False, True])
def test_cancelled_reschedule_worker_state(ws, long_running):
    """Same as test_cancelled_reschedule"""

    ws.handle_stimulus(ComputeTaskEvent.dummy(key="x", stimulus_id="s1"))
    if long_running:
        ws.handle_stimulus(SecedeEvent(key="x", compute_duration=1.0, stimulus_id="s2"))

    instructions = ws.handle_stimulus(
        FreeKeysEvent(keys=["x"], stimulus_id="s3"),
        RescheduleEvent(key="x", stimulus_id="s4"),
    )
    # There's no RescheduleMsg and the task has been forgotten
    assert not instructions
    assert not ws.tasks
41 changes: 7 additions & 34 deletions distributed/tests/test_scheduler.py
@@ -1132,9 +1132,12 @@ async def test_balance_many_workers(c, s, *workers):
 
 
 @nodebug
-@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 30)
+@gen_cluster(
+    client=True,
+    nthreads=[("127.0.0.1", 1)] * 30,
+    config={"distributed.scheduler.work-stealing": False},
+)
 async def test_balance_many_workers_2(c, s, *workers):
-    await s.extensions["stealing"].stop()
     futures = c.map(slowinc, range(90), delay=0.2)
     await wait(futures)
     assert {len(w.has_what) for w in s.workers.values()} == {3}
@@ -1513,35 +1516,6 @@ async def test_log_tasks_during_restart(c, s, a, b):
     assert "exit" in str(s.events)
 
 
-@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
-async def test_reschedule(c, s, a, b):
-    await c.submit(slowinc, -1, delay=0.1)  # learn cost
-    x = c.map(slowinc, range(4), delay=0.1)
-
-    # add much more work onto worker a
-    futures = c.map(slowinc, range(10, 20), delay=0.1, workers=a.address)
-
-    while len(s.tasks) < len(x) + len(futures):
-        await asyncio.sleep(0.001)
-
-    for future in x:
-        s.reschedule(key=future.key)
Review comment (Collaborator, author): This test remained green if you removed these lines.

-    # Worker b gets more of the original tasks
-    await wait(x)
-    assert sum(future.key in b.data for future in x) >= 3
-    assert sum(future.key in a.data for future in x) <= 1
-
-
-@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
-async def test_reschedule_warns(c, s, a, b):
-    with captured_logger(logging.getLogger("distributed.scheduler")) as sched:
-        s.reschedule(key="__this-key-does-not-exist__")
-
-    assert "not found on the scheduler" in sched.getvalue()
-    assert "Aborting reschedule" in sched.getvalue()
-
-
 @gen_cluster(client=True)
 async def test_get_task_status(c, s, a, b):
     future = c.submit(inc, 1)
@@ -3342,12 +3316,11 @@ async def test_worker_heartbeat_after_cancel(c, s, *workers):
 
 @gen_cluster(client=True, nthreads=[("", 1)] * 2)
 async def test_set_restrictions(c, s, a, b):
-
-    f = c.submit(inc, 1, workers=[b.address])
+    f = c.submit(inc, 1, key="f", workers=[b.address])
     await f
     s.set_restrictions(worker={f.key: a.address})
     assert s.tasks[f.key].worker_restrictions == {a.address}
-    s.reschedule(f)
Review comment (@crusaderky, author, Jun 28, 2022): This was silently doing nothing (future f is not in s.tasks). The task was executed only once, on b.

     await b.close()
     await f
 
 
4 changes: 2 additions & 2 deletions distributed/tests/test_steal.py
@@ -1104,7 +1104,7 @@ async def test_steal_reschedule_reset_in_flight_occupancy(c, s, *workers):
 
     steal.move_task_request(victim_ts, wsA, wsB)
 
-    s.reschedule(victim_key)
+    s.reschedule(victim_key, stimulus_id="test")
     await c.gather(futs1)
 
     del futs1
@@ -1238,7 +1238,7 @@ async def test_reschedule_concurrent_requests_deadlock(c, s, *workers):
     steal.move_task_request(victim_ts, wsA, wsB)
 
     s.set_restrictions(worker={victim_key: [wsB.address]})
-    s.reschedule(victim_key)
+    s.reschedule(victim_key, stimulus_id="test")
     assert wsB == victim_ts.processing_on
     # move_task_request is not responsible for respecting worker restrictions
     steal.move_task_request(victim_ts, wsB, wsC)
18 changes: 0 additions & 18 deletions distributed/tests/test_worker.py
@@ -32,7 +32,6 @@
     Client,
     Event,
     Nanny,
-    Reschedule,
     default_client,
     get_client,
     get_worker,
@@ -1181,23 +1180,6 @@ def some_name():
     assert result.startswith("some_name")
 
 
-@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
-async def test_reschedule(c, s, a, b):
-    await s.extensions["stealing"].stop()
-    a_address = a.address
-
-    def f(x):
-        sleep(0.1)
-        if get_worker().address == a_address:
-            raise Reschedule()
-
-    futures = c.map(f, range(4))
-    futures2 = c.map(slowinc, range(10), delay=0.1, workers=a.address)
-    await wait(futures)
-
-    assert all(f.key in b.data for f in futures)
-
-
 @gen_cluster(nthreads=[])
 async def test_deque_handler(s):
     from distributed.worker import logger
2 changes: 1 addition & 1 deletion distributed/worker.py
@@ -2236,7 +2236,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent | No
             )
 
             if isinstance(result["actual-exception"], Reschedule):
-                return RescheduleEvent(key=ts.key, stimulus_id=stimulus_id)
+                return RescheduleEvent(key=ts.key, stimulus_id=f"reschedule-{time()}")
 
             logger.warning(
                 "Compute Failed\n"
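For context, this hunk is reached when user code raises the Reschedule exception. Below is a minimal sketch of that pattern, mirroring test_raise_reschedule above; the helper name and the address predicate are illustrative, not from this PR.

from distributed import Reschedule, get_worker


def compute_or_move(x, bad_address):  # illustrative helper
    # If this worker is a poor fit, abandon the attempt: Worker.execute()
    # converts the exception into a RescheduleEvent carrying a fresh
    # f"reschedule-{time()}" stimulus_id, and the scheduler releases the
    # task so it can be assigned again, possibly to a different worker.
    if get_worker().address == bad_address:
        raise Reschedule()
    return x + 1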