Select queued tasks in stimuli, not transitions (#7402)
gjoseph92 authored Dec 14, 2022
1 parent 06bf5b3 commit a9c2401
Showing 4 changed files with 149 additions and 63 deletions.
2 changes: 1 addition & 1 deletion distributed/collections.py
@@ -121,7 +121,7 @@ def peekn(self, n: int) -> Iterator[T]:
"""Iterate over the n smallest elements without removing them.
This is O(1) for n == 1; O(n*logn) otherwise.
"""
if n <= 0:
if n <= 0 or not self:
return # empty iterator
if n == 1:
yield self.peek()
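With the `or not self` guard added above, `peekn` on an empty `HeapSet` now yields nothing instead of calling `peek()` on an empty heap. A minimal sketch of the new behaviour (illustration only, not part of this commit; it assumes `HeapSet`'s keyword-only `key` argument, and the `Item` class here is a hypothetical stand-in for whatever is stored):

```python
from distributed.collections import HeapSet


class Item:
    """Any hashable object works; HeapSet orders elements by the key callable."""

    def __init__(self, priority: int, name: str) -> None:
        self.priority = priority
        self.name = name


heap: HeapSet[Item] = HeapSet(key=lambda item: item.priority)

# Empty heap: peekn() is now an empty iterator for any n, rather than an error.
assert list(heap.peekn(1)) == []

heap.add(Item(1, "a"))
heap.add(Item(2, "b"))

# The n smallest elements, in priority order, without removing them.
assert [item.name for item in heap.peekn(2)] == ["a", "b"]
assert len(heap) == 2
```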
121 changes: 63 additions & 58 deletions distributed/scheduler.py
@@ -2317,18 +2317,13 @@ def transition_processing_memory(
############################
# Update State Information #
############################
recommendations: Recs = {}
client_msgs: Msgs = {}

if nbytes is not None:
ts.set_nbytes(nbytes)

# NOTE: recommendations for queued tasks are added first, so they'll be popped
# last, allowing higher-priority downstream tasks to be transitioned first.
# FIXME: this would be incorrect if queued tasks are user-annotated as higher
# priority.
self._exit_processing_common(ts, recommendations)
self._exit_processing_common(ts)

recommendations: Recs = {}
client_msgs: Msgs = {}
self._add_to_memory(
ts, ws, recommendations, client_msgs, type=type, typename=typename
)
@@ -2507,7 +2502,7 @@ def transition_processing_released(self, key: str, stimulus_id: str) -> RecsMsgs
assert not ts.waiting_on
assert ts.state == "processing"

ws = self._exit_processing_common(ts, recommendations)
ws = self._exit_processing_common(ts)
if ws:
worker_msgs[ws.address] = [
{
@@ -2574,7 +2569,7 @@ def transition_processing_erred(
assert ws
ws.actors.remove(ts)

self._exit_processing_common(ts, recommendations)
self._exit_processing_common(ts)

ts.erred_on.add(worker)
if exception is not None:
@@ -3059,24 +3054,20 @@ def remove_all_replicas(self, ts: TaskState) -> None:
self.replicated_tasks.remove(ts)
ts.who_has.clear()

def bulk_schedule_after_adding_worker(self, ws: WorkerState) -> Recs:
"""Send ``queued`` or ``no-worker`` tasks to ``processing`` that this worker can
handle.
def bulk_schedule_unrunnable_after_adding_worker(self, ws: WorkerState) -> Recs:
"""Send ``no-worker`` tasks to ``processing`` that this worker can handle.
Returns priority-ordered recommendations.
"""
maybe_runnable = list(self._next_queued_tasks_for_worker(ws))[::-1]

# Schedule any restricted tasks onto the new worker, if the worker can run them
runnable: list[TaskState] = []
for ts in self.unrunnable:
valid = self.valid_workers(ts)
if valid is None or ws in valid:
maybe_runnable.append(ts)
runnable.append(ts)

# Recommendations are processed LIFO, hence the reversed order
maybe_runnable.sort(key=operator.attrgetter("priority"), reverse=True)
# Note not all will necessarily be run; transition->processing will decide
return {ts.key: "processing" for ts in maybe_runnable}
runnable.sort(key=operator.attrgetter("priority"), reverse=True)
return {ts.key: "processing" for ts in runnable}

def _validate_ready(self, ts: TaskState) -> None:
"""Validation for ready states (processing, queued, no-worker)"""
@@ -3108,9 +3099,7 @@ def _add_to_processing(self, ts: TaskState, ws: WorkerState) -> Msgs:

return {ws.address: [self._task_to_msg(ts)]}

def _exit_processing_common(
self, ts: TaskState, recommendations: Recs
) -> WorkerState | None:
def _exit_processing_common(self, ts: TaskState) -> WorkerState | None:
"""Remove *ts* from the set of processing tasks.
Returns
@@ -3133,28 +3122,8 @@ def _exit_processing_common(
self.check_idle_saturated(ws)
self.release_resources(ts, ws)

for qts in self._next_queued_tasks_for_worker(ws):
if self.validate:
assert qts.key not in recommendations, recommendations[qts.key]
recommendations[qts.key] = "processing"

return ws

def _next_queued_tasks_for_worker(self, ws: WorkerState) -> Iterator[TaskState]:
"""Queued tasks to run, in priority order, on all open slots on a worker"""
if not self.queued or ws.status != Status.running:
return

# NOTE: this is called most frequently because a single task has completed, so
# there are <= 1 task slots available on the worker.
# `peekn` has fast paths for the cases N<=0 and N==1.
for qts in self.queued.peekn(_task_slots_available(ws, self.WORKER_SATURATION)):
if self.validate:
assert qts.state == "queued", qts.state
assert not qts.processing_on
assert not qts.waiting_on
yield qts

def _add_to_memory(
self,
ts: TaskState,
@@ -4240,7 +4209,10 @@ async def add_worker(
logger.exception(e)

if ws.status == Status.running:
self.transitions(self.bulk_schedule_after_adding_worker(ws), stimulus_id)
self.transitions(
self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id
)
self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id)

logger.info("Register worker %s", ws)

@@ -4620,6 +4592,42 @@ def update_graph(

# TODO: balance workers

def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None:
"""Respond to an event which may have opened spots on worker threadpools
Selects the appropriate number of tasks from the front of the queue according to
the total number of task slots available on workers (potentially 0), and
transitions them to ``processing``.
Notes
-----
Other transitions related to this stimulus should be fully processed beforehand,
so any tasks that became runnable are already in ``processing``. Otherwise,
overproduction can occur if queued tasks get scheduled before downstream tasks.
Must be called after `check_idle_saturated`; i.e. `idle_task_count` must be up
to date.
"""
if not self.queued:
return
slots_available = sum(
_task_slots_available(ws, self.WORKER_SATURATION)
for ws in self.idle_task_count
)
if slots_available == 0:
return

recommendations: Recs = {}
for qts in self.queued.peekn(slots_available):
if self.validate:
assert qts.state == "queued", qts.state
assert not qts.processing_on, (qts, qts.processing_on)
assert not qts.waiting_on, (qts, qts.processing_on)
assert qts.who_wants or qts.waiters, qts
recommendations[qts.key] = "processing"

self.transitions(recommendations, stimulus_id)

def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwargs):
"""Mark that a task has finished execution on a particular worker"""
logger.debug("Stimulus task finished %s, %s", key, worker)
@@ -4946,6 +4954,8 @@ def client_releases_keys(self, keys=None, client=None, stimulus_id=None):
self._client_releases_keys(keys=keys, cs=cs, recommendations=recommendations)
self.transitions(recommendations, stimulus_id)

self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id)

def client_heartbeat(self, client=None):
"""Handle heartbeats from Client"""
cs: ClientState = self.clients[client]
@@ -5290,15 +5300,18 @@ def handle_task_finished(
)
recommendations, client_msgs, worker_msgs = r
self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id)

self.send_all(client_msgs, worker_msgs)

self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id)

def handle_task_erred(self, key: str, stimulus_id: str, **msg) -> None:
r: tuple = self.stimulus_task_erred(key=key, stimulus_id=stimulus_id, **msg)
recommendations, client_msgs, worker_msgs = r
self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id)
self.send_all(client_msgs, worker_msgs)

self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id)

def release_worker_data(self, key: str, worker: str, stimulus_id: str) -> None:
ts = self.tasks.get(key)
ws = self.workers.get(worker)
@@ -5340,13 +5353,7 @@ def handle_long_running(
ws.add_to_long_running(ts)
self.check_idle_saturated(ws)

recommendations: Recs = {
qts.key: "processing" for qts in self._next_queued_tasks_for_worker(ws)
}
if self.validate:
assert len(recommendations) <= 1, (ws, recommendations)

self.transitions(recommendations, stimulus_id)
self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id)

def handle_worker_status_change(
self, status: str | Status, worker: str | WorkerState, stimulus_id: str
@@ -5372,12 +5379,10 @@
if ws.status == Status.running:
self.running.add(ws)
self.check_idle_saturated(ws)
recs = self.bulk_schedule_after_adding_worker(ws)
if recs:
client_msgs: Msgs = {}
worker_msgs: Msgs = {}
self._transitions(recs, client_msgs, worker_msgs, stimulus_id)
self.send_all(client_msgs, worker_msgs)
self.transitions(
self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id
)
self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id)
else:
self.running.discard(ws)
self.idle.pop(ws.address, None)
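The new `stimulus_queue_slots_maybe_opened` above sums `_task_slots_available` over the workers in `idle_task_count`, then peeks at most that many tasks from the front of the scheduler-side queue and recommends them to `processing`. The slot arithmetic is roughly the following (a simplified sketch with a hypothetical `approx_task_slots_available` helper, not the scheduler's actual `_task_slots_available`):

```python
import math


def approx_task_slots_available(
    nthreads: int, n_processing: int, worker_saturation: float = 1.0
) -> int:
    """Roughly how many more tasks fit on a worker before it is saturated."""
    capacity = max(math.ceil(nthreads * worker_saturation), 1)
    return max(capacity - n_processing, 0)


# With worker-saturation=1.0, a 4-thread worker already running 3 tasks has one
# open slot, so one queued task would be recommended to "processing" for it.
assert approx_task_slots_available(nthreads=4, n_processing=3) == 1

# With worker-saturation=1.5 the same worker may hold ceil(4 * 1.5) = 6 tasks,
# leaving three open slots.
assert approx_task_slots_available(4, 3, worker_saturation=1.5) == 3
```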
6 changes: 6 additions & 0 deletions distributed/tests/test_collections.py
@@ -148,8 +148,14 @@ def test_heapset():
assert list(heap.peekn(1)) == [cx]
heap.remove(cw)
assert list(heap.peekn(1)) == [cx]
heap.remove(cx)
assert list(heap.peekn(-1)) == []
assert list(heap.peekn(0)) == []
assert list(heap.peekn(1)) == []
assert list(heap.peekn(2)) == []

# Test resilience to failure in key()
heap.add(cx)
bad_key = C("bad_key", 0)
del bad_key.i
with pytest.raises(AttributeError):
83 changes: 79 additions & 4 deletions distributed/tests/test_scheduler.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import itertools
import json
import logging
import math
@@ -86,14 +87,14 @@ async def test_administration(s, a, b):

@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
async def test_respect_data_in_memory(c, s, a):
x = delayed(inc)(1)
y = delayed(inc)(x)
x = delayed(inc)(1, dask_key_name="x")
y = delayed(inc)(x, dask_key_name="y")
f = c.persist(y)
await wait([f])

assert s.tasks[y.key].who_has == {s.workers[a.address]}

z = delayed(operator.add)(x, y)
z = delayed(operator.add)(x, y, dask_key_name="z")
f2 = c.persist(z)
while f2.key not in s.tasks or not s.tasks[f2.key]:
assert s.tasks[y.key].who_has
@@ -371,6 +372,80 @@ def __del__(self):
assert max(Refcount.log) <= s.total_nthreads


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_forget_tasks_while_processing(c, s, a, b):
events = [Event() for _ in range(10)]

futures = c.map(Event.wait, events)
await events[0].set()
await futures[0]
await c.close()
assert not s.tasks


@pytest.mark.slow
@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny)
async def test_restart_while_processing(c, s, a, b):
events = [Event() for _ in range(10)]

futures = c.map(Event.wait, events)
await events[0].set()
await futures[0]
# TODO slow because worker waits a while for the task to finish
await c.restart()
assert not s.tasks


@gen_cluster(
client=True,
nthreads=[("", 1)] * 3,
config={"distributed.scheduler.worker-saturation": 1.0},
)
async def test_queued_release_multiple_workers(c, s, *workers):
async with Client(s.address, asynchronous=True) as c2:
event = Event(client=c2)

rootish_threshold = s.total_nthreads * 2 + 1

first_batch = c.map(
lambda i: event.wait(),
range(rootish_threshold),
key=[f"first-{i}" for i in range(rootish_threshold)],
)
await async_wait_for(lambda: s.queued, 5)

second_batch = c2.map(
lambda i: event.wait(),
range(rootish_threshold),
key=[f"second-{i}" for i in range(rootish_threshold)],
fifo_timeout=0,
)
await async_wait_for(lambda: second_batch[0].key in s.tasks, 5)

# All of the second batch should be queued after the first batch
assert [ts.key for ts in s.queued.sorted()] == [
f.key
for f in itertools.chain(first_batch[s.total_nthreads :], second_batch)
]

# Cancel the first batch.
# Use `Client.close` instead of `del first_batch` because deleting futures sends cancellation
# messages one at a time. We're testing here that when multiple workers have open slots, we don't
# recommend the same queued tasks for every worker, so we need a bulk cancellation operation.
await c.close()
del c, first_batch

await async_wait_for(lambda: len(s.tasks) == len(second_batch), 5)

# Second batch should move up the queue and start processing
assert len(s.queued) == len(second_batch) - s.total_nthreads, list(
s.queued.sorted()
)

await event.set()
await c2.gather(second_batch)


@gen_cluster(
client=True,
nthreads=[("", 2)] * 2,
@@ -4237,7 +4312,7 @@ def assert_rootish():
await asyncio.sleep(0.005)
assert_rootish()
if rootish:
assert all(s.tasks[k] in s.queued for k in keys)
assert all(s.tasks[k] in s.queued for k in keys), [s.tasks[k] for k in keys]
await block.set()
# At this point we need/want to wait for the task-finished message to
# arrive on the scheduler. There is no proper hook to wait, therefore we
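For reference, the scheduler-side queuing these tests exercise is controlled by the `distributed.scheduler.worker-saturation` setting passed to `gen_cluster` above. Outside the test suite it would typically be set through dask's configuration; a usage sketch (not part of this commit):

```python
import dask
from distributed import Client, LocalCluster

# Cap each worker at roughly one task per thread; excess root-ish tasks then
# wait in the scheduler-side queue until slots open up.
with dask.config.set({"distributed.scheduler.worker-saturation": 1.0}):
    with LocalCluster(n_workers=2, threads_per_worker=1) as cluster:
        with Client(cluster) as client:
            futures = client.map(lambda x: x + 1, range(100))
            print(sum(client.gather(futures)))  # 5050
```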
