Merge branch 'main' into support-stimulus-id-in-reschedule
crusaderky committed May 13, 2022
2 parents e02e34c + 50d2911 · commit d3dfa55
Showing 11 changed files with 215 additions and 184 deletions.
11 changes: 6 additions & 5 deletions distributed/client.py
@@ -4418,16 +4418,14 @@ async def _get_task_stream(
else:
return msgs

async def _register_scheduler_plugin(self, plugin, name, **kwargs):
if isinstance(plugin, type):
plugin = plugin(**kwargs)

async def _register_scheduler_plugin(self, plugin, name, idempotent=False):
return await self.scheduler.register_scheduler_plugin(
plugin=dumps(plugin, protocol=4),
name=name,
idempotent=idempotent,
)

def register_scheduler_plugin(self, plugin, name=None):
def register_scheduler_plugin(self, plugin, name=None, idempotent=False):
"""Register a scheduler plugin.
See https://distributed.readthedocs.io/en/latest/plugins.html#scheduler-plugins
@@ -4439,6 +4437,8 @@ def register_scheduler_plugin(self, plugin, name=None):
name : str
Name for the plugin; if None, a name is taken from the
plugin instance or automatically generated if not present.
idempotent : bool
Do not re-register if a plugin of the given name already exists.
"""
if name is None:
name = _get_plugin_name(plugin)
@@ -4447,6 +4447,7 @@ def register_scheduler_plugin(self, plugin, name=None):
self._register_scheduler_plugin,
plugin=plugin,
name=name,
idempotent=idempotent,
)

def register_worker_callbacks(self, setup=None):
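A quick usage sketch of the new idempotent flag on the client API above; the DummyPlugin class and the scheduler address are made up for illustration and are not part of this commit:

from distributed import Client
from distributed.diagnostics.plugin import SchedulerPlugin


class DummyPlugin(SchedulerPlugin):
    # Hypothetical no-op plugin, used only to demonstrate registration
    name = "dummy"


client = Client("tcp://127.0.0.1:8786")  # assumed scheduler address

# The first call registers the plugin on the scheduler; the second becomes a
# no-op because a plugin named "dummy" already exists and idempotent=True
# suppresses re-registration.
client.register_scheduler_plugin(DummyPlugin(), idempotent=True)
client.register_scheduler_plugin(DummyPlugin(), idempotent=True)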
80 changes: 38 additions & 42 deletions distributed/scheduler.py
@@ -1841,7 +1841,7 @@ def transition_waiting_processing(self, key, stimulus_id):
assert not ts.processing_on
assert not ts.has_lost_dependencies
assert ts not in self.unrunnable
assert all([dts.who_has for dts in ts.dependencies])
assert all(dts.who_has for dts in ts.dependencies)

ws = self.decide_worker(ts)
if ws is None:
@@ -4683,53 +4683,48 @@ def handle_task_erred(self, key: str, stimulus_id: str, **msg) -> None:
self.send_all(client_msgs, worker_msgs)

def handle_missing_data(
self, key: str, errant_worker: str, stimulus_id: str, **kwargs
self, key: str, worker: str, errant_worker: str, stimulus_id: str
) -> None:
"""Signal that `errant_worker` does not hold `key`
"""Signal that `errant_worker` does not hold `key`.
This may either indicate that `errant_worker` is dead or that we may be
working with stale data and need to remove `key` from the workers
`has_what`.
If no replica of a task is available anymore, the task is transitioned
back to released and rescheduled, if possible.
This may either indicate that `errant_worker` is dead or that we may be working
with stale data and need to remove `key` from the workers `has_what`. If no
replica of a task is available anymore, the task is transitioned back to
released and rescheduled, if possible.
Parameters
----------
key : str, optional
Task key that could not be found, by default None
errant_worker : str, optional
Address of the worker supposed to hold a replica, by default None
key : str
Task key that could not be found
worker : str
Address of the worker informing the scheduler
errant_worker : str
Address of the worker supposed to hold a replica
"""
logger.debug("handle missing data key=%s worker=%s", key, errant_worker)
logger.debug(f"handle missing data {key=} {worker=} {errant_worker=}")
self.log_event(errant_worker, {"action": "missing-data", "key": key})

if key not in self.tasks:
ts = self.tasks.get(key)
ws = self.workers.get(errant_worker)
if not ts or not ws or ws not in ts.who_has:
return

ts: TaskState = self.tasks[key]
ws: WorkerState = self.workers.get(errant_worker)

if ws is not None and ws in ts.who_has:
self.remove_replica(ts, ws)
self.remove_replica(ts, ws)
if ts.state == "memory" and not ts.who_has:
if ts.run_spec:
self.transitions({key: "released"}, stimulus_id)
else:
self.transitions({key: "forgotten"}, stimulus_id)

def release_worker_data(self, key, worker, stimulus_id):
ws: WorkerState = self.workers.get(worker)
ts: TaskState = self.tasks.get(key)
if not ws or not ts:
def release_worker_data(self, key: str, worker: str, stimulus_id: str) -> None:
ts = self.tasks.get(key)
ws = self.workers.get(worker)
if not ts or not ws or ws not in ts.who_has:
return
recommendations: dict = {}
if ws in ts.who_has:
self.remove_replica(ts, ws)
if not ts.who_has:
recommendations[ts.key] = "released"
if recommendations:
self.transitions(recommendations, stimulus_id)

self.remove_replica(ts, ws)
if not ts.who_has:
self.transitions({key: "released"}, stimulus_id)

def handle_long_running(self, key=None, worker=None, compute_duration=None):
"""A task has seceded from the thread pool
@@ -4907,7 +4902,7 @@ async def register_scheduler_plugin(self, plugin, name=None, idempotent=None):

self.add_plugin(plugin, name=name, idempotent=idempotent)

def worker_send(self, worker, msg):
def worker_send(self, worker: str, msg: dict[str, Any]) -> None:
"""Send message to worker
This also handles connection failures by adding a callback to remove
@@ -7127,6 +7122,9 @@ def request_acquire_replicas(self, addr: str, keys: list, *, stimulus_id: str):
ts = self.tasks[key]
who_has[key] = {ws.address for ws in ts.who_has}

if self.validate:
assert all(who_has.values())

self.stream_comms[addr].send(
{
"op": "acquire-replicas",
@@ -7331,21 +7329,19 @@ def _task_to_msg(
"priority": ts.priority,
"duration": duration,
"stimulus_id": f"compute-task-{time()}",
"who_has": {},
"who_has": {
dts.key: [ws.address for ws in dts.who_has] for dts in ts.dependencies
},
"nbytes": {dts.key: dts.nbytes for dts in ts.dependencies},
}
if state.validate:
assert all(msg["who_has"].values())

if ts.resource_restrictions:
msg["resource_restrictions"] = ts.resource_restrictions
if ts.actor:
msg["actor"] = True

deps = ts.dependencies
if deps:
msg["who_has"] = {dts.key: [ws.address for ws in dts.who_has] for dts in deps}
msg["nbytes"] = {dts.key: dts.nbytes for dts in deps}

if state.validate:
assert all(msg["who_has"].values())

if isinstance(ts.run_spec, dict):
msg.update(ts.run_spec)
else:
@@ -7482,7 +7478,7 @@ def validate_task_state(ts: TaskState) -> None:
assert bool(ts.who_has) == (ts.state == "memory"), (ts, ts.who_has, ts.state)

if ts.state == "processing":
assert all([dts.who_has for dts in ts.dependencies]), (
assert all(dts.who_has for dts in ts.dependencies), (
"task processing without all deps",
str(ts),
str(ts.dependencies),
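To make the _task_to_msg change above concrete: every compute-task message now carries who_has and nbytes entries for all dependencies (previously those fields were only filled in when the task had dependencies), and validation asserts that each dependency still has at least one replica. A rough sketch of the resulting message shape; the keys, addresses, and sizes below are invented for illustration, not taken from a real cluster:

# Illustrative shape of a scheduler -> worker "compute-task" message after
# this change; values are made up for the example.
msg = {
    "op": "compute-task",
    "key": "slowadd-1-0",
    "priority": (0, 1, 0),
    "duration": 0.5,
    "stimulus_id": "compute-task-1652444400.0",
    # Always present now, one entry per dependency:
    "who_has": {"scatter-0": ["tcp://127.0.0.1:33851"]},
    "nbytes": {"scatter-0": 80_000},
}

# With validation enabled, the scheduler asserts that no dependency has an
# empty replica list:
assert all(msg["who_has"].values())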
57 changes: 42 additions & 15 deletions distributed/stealing.py
@@ -6,6 +6,7 @@
from collections.abc import Container
from math import log2
from time import time
from typing import TYPE_CHECKING, ClassVar, TypedDict

from tlz import topk
from tornado.ioloop import PeriodicCallback
@@ -18,6 +19,10 @@
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.utils import log_errors, recursive_to_dict

if TYPE_CHECKING:
# Recursive imports
from distributed.scheduler import Scheduler, TaskState, WorkerState

# Stealing requires multiple network bounces and if successful also task
# submission which may include code serialization. Therefore, be very
# conservative in the latency estimation to suppress too aggressive stealing
@@ -48,18 +53,40 @@
}


class InFlightInfo(TypedDict):
victim: WorkerState
thief: WorkerState
victim_duration: float
thief_duration: float
stimulus_id: str


class WorkStealing(SchedulerPlugin):
def __init__(self, scheduler):
scheduler: Scheduler
# ({ task states for level 0}, ..., {task states for level 14})
stealable_all: tuple[set[TaskState], ...]
# {worker: ({ task states for level 0}, ..., {task states for level 14})}
stealable: dict[str, tuple[set[TaskState], ...]]
# { task state: (worker, level) }
key_stealable: dict[TaskState, tuple[str, int]]
# (multiplier for level 0, ... multiplier for level 14)
cost_multipliers: ClassVar[tuple[float, ...]] = (1.0,) + tuple(
1 + 2 ** (i - 6) for i in range(1, 15)
)
_callback_time: float | None
count: int
# { task state: <stealing info dict> }
in_flight: dict[TaskState, InFlightInfo]
# { worker state: occupancy }
in_flight_occupancy: defaultdict[WorkerState, float]
_in_flight_event: asyncio.Event
_request_counter: int

def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
# { level: { task states } }
self.stealable_all = [set() for i in range(15)]
# { worker: { level: { task states } } }
self.stealable = dict()
# { task state: (worker, level) }
self.key_stealable = dict()

self.cost_multipliers = [1 + 2 ** (i - 6) for i in range(15)]
self.cost_multipliers[0] = 1
self.stealable_all = tuple(set() for _ in range(15))
self.stealable = {}
self.key_stealable = {}

for worker in scheduler.workers:
self.add_worker(worker=worker)
@@ -72,9 +99,7 @@ def __init__(self, scheduler):
self.scheduler.add_plugin(self)
self.scheduler.events["stealing"] = deque(maxlen=100000)
self.count = 0
# { task state: <stealing info dict> }
self.in_flight = dict()
# { worker state: occupancy }
self.in_flight = {}
self.in_flight_occupancy = defaultdict(lambda: 0)
self._in_flight_event = asyncio.Event()
self._request_counter = 0
@@ -121,7 +146,7 @@ def log(self, msg):
return self.scheduler.log_event("stealing", msg)

def add_worker(self, scheduler=None, worker=None):
self.stealable[worker] = [set() for i in range(15)]
self.stealable[worker] = tuple(set() for _ in range(15))

def remove_worker(self, scheduler=None, worker=None):
del self.stealable[worker]
@@ -213,7 +238,9 @@ def steal_time_ratio(self, ts):

return cost_multiplier, level

def move_task_request(self, ts, victim, thief) -> str:
def move_task_request(
self, ts: TaskState, victim: WorkerState, thief: WorkerState
) -> str:
try:
if ts in self.in_flight:
return "in-flight"
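For reference on the stealing levels typed above: cost_multipliers is now a class-level constant with a flat multiplier of 1 for level 0 and 1 + 2**(i - 6) for levels 1-14, a tuple-based restatement of the old list construction. A small standalone snippet reproducing the expression and its extreme values (not part of the diff):

# Reproduce WorkStealing.cost_multipliers: level 0 is a flat 1.0, levels 1-14
# grow as 1 + 2**(i - 6).
cost_multipliers = (1.0,) + tuple(1 + 2 ** (i - 6) for i in range(1, 15))

print(cost_multipliers[:4])   # (1.0, 1.03125, 1.0625, 1.125)
print(cost_multipliers[-1])   # 257  (level 14: 1 + 2**8)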
2 changes: 1 addition & 1 deletion distributed/tests/test_cancelled_state.py
@@ -83,7 +83,7 @@ def f(ev):
assert_story(
a.story("f1"),
[
("f1", "compute-task"),
("f1", "compute-task", "released"),
("f1", "released", "waiting", "waiting", {"f1": "ready"}),
("f1", "waiting", "ready", "ready", {"f1": "executing"}),
("f1", "ready", "executing", "executing", {}),
2 changes: 1 addition & 1 deletion distributed/tests/test_stories.py
@@ -142,7 +142,7 @@ async def test_worker_story_with_deps(c, s, a, b):

# This is a simple transition log
expected = [
("res", "compute-task"),
("res", "compute-task", "released"),
("res", "released", "waiting", "waiting", {"dep": "fetch"}),
("res", "waiting", "ready", "ready", {"res": "executing"}),
("res", "ready", "executing", "executing", {}),
41 changes: 14 additions & 27 deletions distributed/tests/test_stress.py
@@ -12,7 +12,6 @@
from distributed import Client, Nanny, wait
from distributed.chaos import KillWorker
from distributed.compatibility import WINDOWS
from distributed.config import config
from distributed.metrics import time
from distributed.utils import CancelledError
from distributed.utils_test import (
@@ -121,58 +120,46 @@ async def create_and_destroy_worker(delay):
assert await c.compute(z) == 8000884.93


@gen_cluster(nthreads=[("127.0.0.1", 1)] * 10, client=True, timeout=60)
@gen_cluster(nthreads=[("", 1)] * 10, client=True)
async def test_stress_scatter_death(c, s, *workers):
import random

s.allowed_failures = 1000
np = pytest.importorskip("numpy")
L = await c.scatter([np.random.random(10000) for i in range(len(workers))])
L = await c.scatter(
{f"scatter-{i}": np.random.random(10000) for i in range(len(workers))}
)
L = list(L.values())
await c.replicate(L, n=2)

adds = [
delayed(slowadd, pure=True)(
delayed(slowadd)(
random.choice(L),
random.choice(L),
delay=0.05,
dask_key_name="slowadd-1-%d" % i,
dask_key_name=f"slowadd-1-{i}",
)
for i in range(50)
]

adds = [
delayed(slowadd, pure=True)(a, b, delay=0.02, dask_key_name="slowadd-2-%d" % i)
delayed(slowadd)(a, b, delay=0.02, dask_key_name=f"slowadd-2-{i}")
for i, (a, b) in enumerate(sliding_window(2, adds))
]

futures = c.compute(adds)
L = adds = None

alive = list(workers)
del L
del adds

from distributed.scheduler import logger
for w in random.sample(workers, 7):
s.validate_state()
for w2 in workers:
w2.validate_state()

for i in range(7):
await asyncio.sleep(0.1)
try:
s.validate_state()
except Exception as c:
logger.exception(c)
if config.get("log-on-err"):
import pdb

pdb.set_trace()
else:
raise
w = random.choice(alive)
await w.close()
alive.remove(w)

with suppress(CancelledError):
await c.gather(futures)

futures = None


def vsum(*args):
return sum(args)
