Skip to content

Commit

Permalink
Auto-fail tasks with deps larger than the worker memory (dask#8135)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored Sep 5, 2023
1 parent 7a005fc commit dd56cc6
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 13 deletions.
59 changes: 47 additions & 12 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2078,18 +2078,17 @@ def transition_released_waiting(self, key: str, stimulus_id: str) -> RecsMsgs:

def transition_no_worker_processing(self, key: str, stimulus_id: str) -> RecsMsgs:
ts = self.tasks[key]
worker_msgs: Msgs = {}

if self.validate:
assert not ts.actor, f"Actors can't be in `no-worker`: {ts}"
assert ts in self.unrunnable

if ws := self.decide_worker_non_rootish(ts):
self.unrunnable.discard(ts)
worker_msgs = self._add_to_processing(ts, ws)
return self._add_to_processing(ts, ws, stimulus_id=stimulus_id)
# If no worker, task just stays in `no-worker`

return {}, {}, worker_msgs
return {}, {}, {}

def decide_worker_rootish_queuing_disabled(
self, ts: TaskState
Expand Down Expand Up @@ -2295,8 +2294,7 @@ def transition_waiting_processing(self, key: str, stimulus_id: str) -> RecsMsgs:
if not (ws := self.decide_worker_non_rootish(ts)):
return {ts.key: "no-worker"}, {}, {}

worker_msgs = self._add_to_processing(ts, ws)
return {}, {}, worker_msgs
return self._add_to_processing(ts, ws, stimulus_id=stimulus_id)

def transition_waiting_memory(
self,
Expand Down Expand Up @@ -2751,19 +2749,16 @@ def transition_queued_released(self, key: str, stimulus_id: str) -> RecsMsgs:

def transition_queued_processing(self, key: str, stimulus_id: str) -> RecsMsgs:
ts = self.tasks[key]
recommendations: Recs = {}
worker_msgs: Msgs = {}

if self.validate:
assert not ts.actor, f"Actors can't be queued: {ts}"
assert ts in self.queued

if ws := self.decide_worker_rootish_queuing_enabled():
self.queued.discard(ts)
worker_msgs = self._add_to_processing(ts, ws)
return self._add_to_processing(ts, ws, stimulus_id=stimulus_id)
# If no worker, task just stays `queued`

return recommendations, {}, worker_msgs
return {}, {}, {}

def _remove_key(self, key: str) -> None:
ts = self.tasks.pop(key)
Expand Down Expand Up @@ -3144,7 +3139,9 @@ def _validate_ready(self, ts: TaskState) -> None:
assert ts not in self.queued
assert all(dts.who_has for dts in ts.dependencies)

def _add_to_processing(self, ts: TaskState, ws: WorkerState) -> Msgs:
def _add_to_processing(
self, ts: TaskState, ws: WorkerState, stimulus_id: str
) -> RecsMsgs:
"""Set a task as processing on a worker and return the worker messages to send"""
if self.validate:
self._validate_ready(ts)
Expand All @@ -3161,7 +3158,45 @@ def _add_to_processing(self, ts: TaskState, ws: WorkerState) -> Msgs:
if ts.actor:
ws.actors.add(ts)

return {ws.address: [self._task_to_msg(ts)]}
ndep_bytes = sum(dts.nbytes for dts in ts.dependencies)
if (
ws.memory_limit
and ndep_bytes > ws.memory_limit
and dask.config.get("distributed.worker.memory.terminate")
):
# Note
# ----
# This is a crude safety system, only meant to prevent order-of-magnitude
# fat-finger errors.
#
# For collection finalizers and in general most concat operations, it takes
# a lot less to kill off the worker; you'll just need
# ndep_bytes * 2 > ws.memory_limit * terminate threshold.
#
# In heterogeneous clusters with workers mounting different amounts of
# memory, the user is expected to manually set host/worker/resource
# restrictions.
msg = (
f"Task {ts.key} has {format_bytes(ndep_bytes)} worth of input "
f"dependencies, but worker {ws.address} has memory_limit set to "
f"{format_bytes(ws.memory_limit)}."
)
if ts.prefix.name == "finalize":
msg += (
" It seems like you called client.compute() on a huge collection. "
"Consider writing to distributed storage or slicing/reducing first."
)
logger.error(msg)
return self._transition(
ts.key,
"erred",
exception=pickle.dumps(MemoryError(msg)),
cause=ts.key,
stimulus_id=stimulus_id,
worker=ws.address,
)

return {}, {}, {ws.address: [self._task_to_msg(ts)]}

def _exit_processing_common(self, ts: TaskState) -> WorkerState | None:
"""Remove *ts* from the set of processing tasks.
Expand Down
39 changes: 38 additions & 1 deletion distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import math
import operator
import pickle
import random
import re
import sys
from collections.abc import Collection
Expand All @@ -22,7 +23,7 @@
from tornado.ioloop import IOLoop

import dask
from dask import delayed
from dask import bag, delayed
from dask.core import flatten
from dask.highlevelgraph import HighLevelGraph, MaterializedLayer
from dask.utils import parse_timedelta, tmpfile, typename
Expand Down Expand Up @@ -4472,3 +4473,39 @@ async def test_scatter_creates_ts(c, s, a, b):
await a.close()
assert await x2 == 2
assert s.tasks["x"].run_spec is not None


@pytest.mark.parametrize("finalize", [False, True])
@gen_cluster(
client=True,
nthreads=[("", 1)] * 4,
worker_kwargs={"memory_limit": "100 kB"},
config={
"distributed.worker.memory.target": False,
"distributed.worker.memory.spill": False,
"distributed.worker.memory.pause": False,
},
)
async def test_refuse_to_schedule_huge_task(c, s, *workers, finalize):
"""If the total size of a task's input grossly exceed the memory available on the
worker, the scheduler must refuse to compute it
"""
bg = bag.from_sequence(
[random.randbytes(30_000) for _ in range(4)],
npartitions=4,
)
match = r"worth of input dependencies, but worker .* has memory_limit set to"
if finalize:
fut = c.compute(bg)
match += r".* you called client.compute()"
else:
bg = bg.repartition(npartitions=1).persist()
fut = list(c.futures_of(bg))[0]

with pytest.raises(MemoryError, match=match):
await fut

# The task never reached the workers
for w in workers:
for ev in w.state.log:
assert fut.key not in ev
1 change: 1 addition & 0 deletions distributed/tests/test_worker_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,7 @@ def f(ev):
"distributed.worker.memory.target": False,
"distributed.worker.memory.spill": False,
"distributed.worker.memory.pause": False,
"distributed.worker.memory.terminate": False,
},
),
)
Expand Down

0 comments on commit dd56cc6

Please sign in to comment.