From a9c240156544c99a58e6a6925ae34db7ffbc1096 Mon Sep 17 00:00:00 2001 From: Gabe Joseph Date: Wed, 14 Dec 2022 11:29:28 -0700 Subject: [PATCH] Select queued tasks in stimuli, not transitions (#7402) --- distributed/collections.py | 2 +- distributed/scheduler.py | 121 ++++++++++++++------------ distributed/tests/test_collections.py | 6 ++ distributed/tests/test_scheduler.py | 83 +++++++++++++++++- 4 files changed, 149 insertions(+), 63 deletions(-) diff --git a/distributed/collections.py b/distributed/collections.py index 4b67807ed48..4ce0fcefa41 100644 --- a/distributed/collections.py +++ b/distributed/collections.py @@ -121,7 +121,7 @@ def peekn(self, n: int) -> Iterator[T]: """Iterate over the n smallest elements without removing them. This is O(1) for n == 1; O(n*logn) otherwise. """ - if n <= 0: + if n <= 0 or not self: return # empty iterator if n == 1: yield self.peek() diff --git a/distributed/scheduler.py b/distributed/scheduler.py index dbaa7cfa1c6..089da240f1e 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -2317,18 +2317,13 @@ def transition_processing_memory( ############################ # Update State Information # ############################ - recommendations: Recs = {} - client_msgs: Msgs = {} - if nbytes is not None: ts.set_nbytes(nbytes) - # NOTE: recommendations for queued tasks are added first, so they'll be popped - # last, allowing higher-priority downstream tasks to be transitioned first. - # FIXME: this would be incorrect if queued tasks are user-annotated as higher - # priority. - self._exit_processing_common(ts, recommendations) + self._exit_processing_common(ts) + recommendations: Recs = {} + client_msgs: Msgs = {} self._add_to_memory( ts, ws, recommendations, client_msgs, type=type, typename=typename ) @@ -2507,7 +2502,7 @@ def transition_processing_released(self, key: str, stimulus_id: str) -> RecsMsgs assert not ts.waiting_on assert ts.state == "processing" - ws = self._exit_processing_common(ts, recommendations) + ws = self._exit_processing_common(ts) if ws: worker_msgs[ws.address] = [ { @@ -2574,7 +2569,7 @@ def transition_processing_erred( assert ws ws.actors.remove(ts) - self._exit_processing_common(ts, recommendations) + self._exit_processing_common(ts) ts.erred_on.add(worker) if exception is not None: @@ -3059,24 +3054,20 @@ def remove_all_replicas(self, ts: TaskState) -> None: self.replicated_tasks.remove(ts) ts.who_has.clear() - def bulk_schedule_after_adding_worker(self, ws: WorkerState) -> Recs: - """Send ``queued`` or ``no-worker`` tasks to ``processing`` that this worker can - handle. + def bulk_schedule_unrunnable_after_adding_worker(self, ws: WorkerState) -> Recs: + """Send ``no-worker`` tasks to ``processing`` that this worker can handle. Returns priority-ordered recommendations. 
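+
+        A sketch of the intended call pattern (mirroring how ``add_worker`` and
+        ``handle_worker_status_change`` below consume the result; the returned
+        recommendations are fed straight into ``transitions``)::
+
+            recs = self.bulk_schedule_unrunnable_after_adding_worker(ws)
+            self.transitions(recs, stimulus_id)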
""" - maybe_runnable = list(self._next_queued_tasks_for_worker(ws))[::-1] - - # Schedule any restricted tasks onto the new worker, if the worker can run them + runnable: list[TaskState] = [] for ts in self.unrunnable: valid = self.valid_workers(ts) if valid is None or ws in valid: - maybe_runnable.append(ts) + runnable.append(ts) # Recommendations are processed LIFO, hence the reversed order - maybe_runnable.sort(key=operator.attrgetter("priority"), reverse=True) - # Note not all will necessarily be run; transition->processing will decide - return {ts.key: "processing" for ts in maybe_runnable} + runnable.sort(key=operator.attrgetter("priority"), reverse=True) + return {ts.key: "processing" for ts in runnable} def _validate_ready(self, ts: TaskState) -> None: """Validation for ready states (processing, queued, no-worker)""" @@ -3108,9 +3099,7 @@ def _add_to_processing(self, ts: TaskState, ws: WorkerState) -> Msgs: return {ws.address: [self._task_to_msg(ts)]} - def _exit_processing_common( - self, ts: TaskState, recommendations: Recs - ) -> WorkerState | None: + def _exit_processing_common(self, ts: TaskState) -> WorkerState | None: """Remove *ts* from the set of processing tasks. Returns @@ -3133,28 +3122,8 @@ def _exit_processing_common( self.check_idle_saturated(ws) self.release_resources(ts, ws) - for qts in self._next_queued_tasks_for_worker(ws): - if self.validate: - assert qts.key not in recommendations, recommendations[qts.key] - recommendations[qts.key] = "processing" - return ws - def _next_queued_tasks_for_worker(self, ws: WorkerState) -> Iterator[TaskState]: - """Queued tasks to run, in priority order, on all open slots on a worker""" - if not self.queued or ws.status != Status.running: - return - - # NOTE: this is called most frequently because a single task has completed, so - # there are <= 1 task slots available on the worker. - # `peekn` has fast paths for the cases N<=0 and N==1. - for qts in self.queued.peekn(_task_slots_available(ws, self.WORKER_SATURATION)): - if self.validate: - assert qts.state == "queued", qts.state - assert not qts.processing_on - assert not qts.waiting_on - yield qts - def _add_to_memory( self, ts: TaskState, @@ -4240,7 +4209,10 @@ async def add_worker( logger.exception(e) if ws.status == Status.running: - self.transitions(self.bulk_schedule_after_adding_worker(ws), stimulus_id) + self.transitions( + self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id + ) + self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) logger.info("Register worker %s", ws) @@ -4620,6 +4592,42 @@ def update_graph( # TODO: balance workers + def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None: + """Respond to an event which may have opened spots on worker threadpools + + Selects the appropriate number of tasks from the front of the queue according to + the total number of task slots available on workers (potentially 0), and + transitions them to ``processing``. + + Notes + ----- + Other transitions related to this stimulus should be fully processed beforehand, + so any tasks that became runnable are already in ``processing``. Otherwise, + overproduction can occur if queued tasks get scheduled before downstream tasks. + + Must be called after `check_idle_saturated`; i.e. `idle_task_count` must be up + to date. 
+ """ + if not self.queued: + return + slots_available = sum( + _task_slots_available(ws, self.WORKER_SATURATION) + for ws in self.idle_task_count + ) + if slots_available == 0: + return + + recommendations: Recs = {} + for qts in self.queued.peekn(slots_available): + if self.validate: + assert qts.state == "queued", qts.state + assert not qts.processing_on, (qts, qts.processing_on) + assert not qts.waiting_on, (qts, qts.processing_on) + assert qts.who_wants or qts.waiters, qts + recommendations[qts.key] = "processing" + + self.transitions(recommendations, stimulus_id) + def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwargs): """Mark that a task has finished execution on a particular worker""" logger.debug("Stimulus task finished %s, %s", key, worker) @@ -4946,6 +4954,8 @@ def client_releases_keys(self, keys=None, client=None, stimulus_id=None): self._client_releases_keys(keys=keys, cs=cs, recommendations=recommendations) self.transitions(recommendations, stimulus_id) + self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) + def client_heartbeat(self, client=None): """Handle heartbeats from Client""" cs: ClientState = self.clients[client] @@ -5290,15 +5300,18 @@ def handle_task_finished( ) recommendations, client_msgs, worker_msgs = r self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) - self.send_all(client_msgs, worker_msgs) + self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) + def handle_task_erred(self, key: str, stimulus_id: str, **msg) -> None: r: tuple = self.stimulus_task_erred(key=key, stimulus_id=stimulus_id, **msg) recommendations, client_msgs, worker_msgs = r self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) self.send_all(client_msgs, worker_msgs) + self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) + def release_worker_data(self, key: str, worker: str, stimulus_id: str) -> None: ts = self.tasks.get(key) ws = self.workers.get(worker) @@ -5340,13 +5353,7 @@ def handle_long_running( ws.add_to_long_running(ts) self.check_idle_saturated(ws) - recommendations: Recs = { - qts.key: "processing" for qts in self._next_queued_tasks_for_worker(ws) - } - if self.validate: - assert len(recommendations) <= 1, (ws, recommendations) - - self.transitions(recommendations, stimulus_id) + self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) def handle_worker_status_change( self, status: str | Status, worker: str | WorkerState, stimulus_id: str @@ -5372,12 +5379,10 @@ def handle_worker_status_change( if ws.status == Status.running: self.running.add(ws) self.check_idle_saturated(ws) - recs = self.bulk_schedule_after_adding_worker(ws) - if recs: - client_msgs: Msgs = {} - worker_msgs: Msgs = {} - self._transitions(recs, client_msgs, worker_msgs, stimulus_id) - self.send_all(client_msgs, worker_msgs) + self.transitions( + self.bulk_schedule_unrunnable_after_adding_worker(ws), stimulus_id + ) + self.stimulus_queue_slots_maybe_opened(stimulus_id=stimulus_id) else: self.running.discard(ws) self.idle.pop(ws.address, None) diff --git a/distributed/tests/test_collections.py b/distributed/tests/test_collections.py index 066cf147a33..6db1811072d 100644 --- a/distributed/tests/test_collections.py +++ b/distributed/tests/test_collections.py @@ -148,8 +148,14 @@ def test_heapset(): assert list(heap.peekn(1)) == [cx] heap.remove(cw) assert list(heap.peekn(1)) == [cx] + heap.remove(cx) + assert list(heap.peekn(-1)) == [] + assert list(heap.peekn(0)) == [] + assert list(heap.peekn(1)) == [] + 
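+    # n >= 2 takes the sorted-iteration path rather than the peek() fast path;
+    # it must likewise yield nothing when the heap is empty.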
assert list(heap.peekn(2)) == [] # Test resilience to failure in key() + heap.add(cx) bad_key = C("bad_key", 0) del bad_key.i with pytest.raises(AttributeError): diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 7dc53b822ce..8a5374e515e 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import itertools import json import logging import math @@ -86,14 +87,14 @@ async def test_administration(s, a, b): @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) async def test_respect_data_in_memory(c, s, a): - x = delayed(inc)(1) - y = delayed(inc)(x) + x = delayed(inc)(1, dask_key_name="x") + y = delayed(inc)(x, dask_key_name="y") f = c.persist(y) await wait([f]) assert s.tasks[y.key].who_has == {s.workers[a.address]} - z = delayed(operator.add)(x, y) + z = delayed(operator.add)(x, y, dask_key_name="z") f2 = c.persist(z) while f2.key not in s.tasks or not s.tasks[f2.key]: assert s.tasks[y.key].who_has @@ -371,6 +372,80 @@ def __del__(self): assert max(Refcount.log) <= s.total_nthreads +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_forget_tasks_while_processing(c, s, a, b): + events = [Event() for _ in range(10)] + + futures = c.map(Event.wait, events) + await events[0].set() + await futures[0] + await c.close() + assert not s.tasks + + +@pytest.mark.slow +@gen_cluster(client=True, nthreads=[("", 1)], Worker=Nanny) +async def test_restart_while_processing(c, s, a, b): + events = [Event() for _ in range(10)] + + futures = c.map(Event.wait, events) + await events[0].set() + await futures[0] + # TODO slow because worker waits a while for the task to finish + await c.restart() + assert not s.tasks + + +@gen_cluster( + client=True, + nthreads=[("", 1)] * 3, + config={"distributed.scheduler.worker-saturation": 1.0}, +) +async def test_queued_release_multiple_workers(c, s, *workers): + async with Client(s.address, asynchronous=True) as c2: + event = Event(client=c2) + + rootish_threshold = s.total_nthreads * 2 + 1 + + first_batch = c.map( + lambda i: event.wait(), + range(rootish_threshold), + key=[f"first-{i}" for i in range(rootish_threshold)], + ) + await async_wait_for(lambda: s.queued, 5) + + second_batch = c2.map( + lambda i: event.wait(), + range(rootish_threshold), + key=[f"second-{i}" for i in range(rootish_threshold)], + fifo_timeout=0, + ) + await async_wait_for(lambda: second_batch[0].key in s.tasks, 5) + + # All of the second batch should be queued after the first batch + assert [ts.key for ts in s.queued.sorted()] == [ + f.key + for f in itertools.chain(first_batch[s.total_nthreads :], second_batch) + ] + + # Cancel the first batch. + # Use `Client.close` instead of `del first_batch` because deleting futures sends cancellation + # messages one at a time. We're testing here that when multiple workers have open slots, we don't + # recommend the same queued tasks for every worker, so we need a bulk cancellation operation. 
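+        # (Scheduler-side, closing the client funnels into `client_releases_keys`,
+        # which transitions all of the client's keys in one batch and then calls
+        # `stimulus_queue_slots_maybe_opened` exactly once.)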
+ await c.close() + del c, first_batch + + await async_wait_for(lambda: len(s.tasks) == len(second_batch), 5) + + # Second batch should move up the queue and start processing + assert len(s.queued) == len(second_batch) - s.total_nthreads, list( + s.queued.sorted() + ) + + await event.set() + await c2.gather(second_batch) + + @gen_cluster( client=True, nthreads=[("", 2)] * 2, @@ -4237,7 +4312,7 @@ def assert_rootish(): await asyncio.sleep(0.005) assert_rootish() if rootish: - assert all(s.tasks[k] in s.queued for k in keys) + assert all(s.tasks[k] in s.queued for k in keys), [s.tasks[k] for k in keys] await block.set() # At this point we need/want to wait for the task-finished message to # arrive on the scheduler. There is no proper hook to wait, therefore we