Validate and debug state machine on handle_compute_task (#6327)
Partially closes #6305
crusaderky authored May 13, 2022
1 parent 4b34bd4 commit 79d5a77
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 86 deletions.
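
The core pattern in the scheduler and worker changes below is the same throughout: new invariant checks sit behind the existing validate flag, so they run in the test suite (where validation is switched on) without slowing production. A minimal illustrative sketch of that pattern follows; `MiniWorker` and its fields are simplified stand-ins invented for this sketch, not the real `distributed.worker.Worker` API:

```python
class MiniWorker:
    """Hypothetical stand-in illustrating the validate-guarded checks."""

    def __init__(self, validate: bool = False):
        self.validate = validate  # mirrors the validate flag used in the test suite
        self.tasks: dict[str, str] = {}  # key -> state (heavily simplified)

    def handle_compute_task(
        self, key: str, who_has: dict[str, list[str]], nbytes: dict[str, int]
    ) -> None:
        if self.validate:
            # The invariants this commit asserts on the worker side:
            # nbytes covers exactly the dependency keys, and every
            # dependency arrives with at least one known replica.
            assert who_has.keys() == nbytes.keys()
            assert all(who_has.values())
        self.tasks[key] = "waiting"


w = MiniWorker(validate=True)
w.handle_compute_task("x", {"dep": ["tcp://10.0.0.1:40000"]}, {"dep": 8})
```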
23 changes: 12 additions & 11 deletions distributed/scheduler.py
@@ -1841,7 +1841,7 @@ def transition_waiting_processing(self, key, stimulus_id):
                 assert not ts.processing_on
                 assert not ts.has_lost_dependencies
                 assert ts not in self.unrunnable
-                assert all([dts.who_has for dts in ts.dependencies])
+                assert all(dts.who_has for dts in ts.dependencies)
 
             ws = self.decide_worker(ts)
             if ws is None:
@@ -7125,6 +7125,9 @@ def request_acquire_replicas(self, addr: str, keys: list, *, stimulus_id: str):
             ts = self.tasks[key]
             who_has[key] = {ws.address for ws in ts.who_has}
 
+        if self.validate:
+            assert all(who_has.values())
+
         self.stream_comms[addr].send(
             {
                 "op": "acquire-replicas",
@@ -7329,21 +7332,19 @@ def _task_to_msg(
         "priority": ts.priority,
         "duration": duration,
         "stimulus_id": f"compute-task-{time()}",
-        "who_has": {},
+        "who_has": {
+            dts.key: [ws.address for ws in dts.who_has] for dts in ts.dependencies
+        },
+        "nbytes": {dts.key: dts.nbytes for dts in ts.dependencies},
     }
+    if state.validate:
+        assert all(msg["who_has"].values())
 
     if ts.resource_restrictions:
         msg["resource_restrictions"] = ts.resource_restrictions
     if ts.actor:
         msg["actor"] = True
 
-    deps = ts.dependencies
-    if deps:
-        msg["who_has"] = {dts.key: [ws.address for ws in dts.who_has] for dts in deps}
-        msg["nbytes"] = {dts.key: dts.nbytes for dts in deps}
-
-    if state.validate:
-        assert all(msg["who_has"].values())
-
     if isinstance(ts.run_spec, dict):
         msg.update(ts.run_spec)
     else:
@@ -7480,7 +7481,7 @@ def validate_task_state(ts: TaskState) -> None:
     assert bool(ts.who_has) == (ts.state == "memory"), (ts, ts.who_has, ts.state)
 
     if ts.state == "processing":
-        assert all([dts.who_has for dts in ts.dependencies]), (
+        assert all(dts.who_has for dts in ts.dependencies), (
             "task processing without all deps",
             str(ts),
             str(ts.dependencies),
2 changes: 1 addition & 1 deletion distributed/tests/test_cancelled_state.py
@@ -83,7 +83,7 @@ def f(ev):
     assert_story(
         a.story("f1"),
         [
-            ("f1", "compute-task"),
+            ("f1", "compute-task", "released"),
             ("f1", "released", "waiting", "waiting", {"f1": "ready"}),
             ("f1", "waiting", "ready", "ready", {"f1": "executing"}),
             ("f1", "ready", "executing", "executing", {}),
2 changes: 1 addition & 1 deletion distributed/tests/test_stories.py
@@ -142,7 +142,7 @@ async def test_worker_story_with_deps(c, s, a, b):
 
     # This is a simple transition log
     expected = [
-        ("res", "compute-task"),
+        ("res", "compute-task", "released"),
         ("res", "released", "waiting", "waiting", {"dep": "fetch"}),
         ("res", "waiting", "ready", "ready", {"res": "executing"}),
         ("res", "ready", "executing", "executing", {}),
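
The expected story tuples in this file, in test_cancelled_state.py above, and in test_worker.py below gain a third element because the worker's compute-task log entry now records the task's state at the moment the request arrives; see the `self.log.append` change in distributed/worker.py. A hedged sketch of that log shape, with a hypothetical helper standing in for the real code path:

```python
from time import time

log: list[tuple] = []  # simplified stand-in for Worker.log

def log_compute_task(key: str, state: str, stimulus_id: str) -> None:
    # Before this commit the entry was (key, "compute-task", stimulus_id, time());
    # afterwards the task's current state is included, which is why expected
    # story tuples grow from ("res", "compute-task") to
    # ("res", "compute-task", "released").
    log.append((key, "compute-task", state, stimulus_id, time()))

log_compute_task("res", "released", "compute-task-1652400000")
assert log[0][:3] == ("res", "compute-task", "released")
```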
41 changes: 14 additions & 27 deletions distributed/tests/test_stress.py
@@ -12,7 +12,6 @@
 from distributed import Client, Nanny, wait
 from distributed.chaos import KillWorker
 from distributed.compatibility import WINDOWS
-from distributed.config import config
 from distributed.metrics import time
 from distributed.utils import CancelledError
 from distributed.utils_test import (
@@ -121,58 +120,46 @@ async def create_and_destroy_worker(delay):
     assert await c.compute(z) == 8000884.93
 
 
-@gen_cluster(nthreads=[("127.0.0.1", 1)] * 10, client=True, timeout=60)
+@gen_cluster(nthreads=[("", 1)] * 10, client=True)
 async def test_stress_scatter_death(c, s, *workers):
     import random
 
     s.allowed_failures = 1000
     np = pytest.importorskip("numpy")
-    L = await c.scatter([np.random.random(10000) for i in range(len(workers))])
+    L = await c.scatter(
+        {f"scatter-{i}": np.random.random(10000) for i in range(len(workers))}
+    )
+    L = list(L.values())
     await c.replicate(L, n=2)
 
     adds = [
-        delayed(slowadd, pure=True)(
+        delayed(slowadd)(
             random.choice(L),
             random.choice(L),
             delay=0.05,
-            dask_key_name="slowadd-1-%d" % i,
+            dask_key_name=f"slowadd-1-{i}",
         )
         for i in range(50)
     ]
 
     adds = [
-        delayed(slowadd, pure=True)(a, b, delay=0.02, dask_key_name="slowadd-2-%d" % i)
+        delayed(slowadd)(a, b, delay=0.02, dask_key_name=f"slowadd-2-{i}")
         for i, (a, b) in enumerate(sliding_window(2, adds))
     ]
 
     futures = c.compute(adds)
-    L = adds = None
-
-    alive = list(workers)
+    del L
+    del adds
 
-    from distributed.scheduler import logger
+    for w in random.sample(workers, 7):
+        s.validate_state()
+        for w2 in workers:
+            w2.validate_state()
 
-    for i in range(7):
         await asyncio.sleep(0.1)
-        try:
-            s.validate_state()
-        except Exception as c:
-            logger.exception(c)
-            if config.get("log-on-err"):
-                import pdb
-
-                pdb.set_trace()
-            else:
-                raise
-        w = random.choice(alive)
         await w.close()
-        alive.remove(w)
 
     with suppress(CancelledError):
         await c.gather(futures)
-
-    futures = None
 
 
 def vsum(*args):
     return sum(args)
4 changes: 2 additions & 2 deletions distributed/tests/test_worker.py
@@ -3157,7 +3157,7 @@ async def test_task_flight_compute_oserror(c, s, a, b):
 
     sum_story = b.story("f1")
     expected_sum_story = [
-        ("f1", "compute-task"),
+        ("f1", "compute-task", "released"),
         (
             "f1",
             "released",
@@ -3174,7 +3174,7 @@ async def test_task_flight_compute_oserror(c, s, a, b):
         ("f1", "waiting", "released", "released", lambda msg: msg["f1"] == "forgotten"),
         ("f1", "released", "forgotten", "forgotten", {}),
         # Now, we actually compute the task *once*. This must not cycle back
-        ("f1", "compute-task"),
+        ("f1", "compute-task", "released"),
         ("f1", "released", "waiting", "waiting", {"f1": "ready"}),
         ("f1", "waiting", "ready", "ready", {"f1": "executing"}),
         ("f1", "ready", "executing", "executing", {}),
6 changes: 3 additions & 3 deletions distributed/tests/test_worker_state_machine.py
@@ -268,17 +268,17 @@ async def test_fetch_to_compute(c, s, a, b):
         # FIXME: This log should be replaced with an
         # StateMachineEvent/Instruction log
         [
-            (f2.key, "compute-task"),
+            (f2.key, "compute-task", "released"),
             # This is a "please fetch" request. We don't have anything like
             # this, yet. We don't see the request-dep signal in here because we
             # do not wait for the key to be actually scheduled
             (f1.key, "ensure-task-exists", "released"),
             # After the worker failed, we're instructed to forget f2 before
             # something new comes in
             ("free-keys", (f2.key,)),
-            (f1.key, "compute-task"),
+            (f1.key, "compute-task", "released"),
             (f1.key, "put-in-memory"),
-            (f2.key, "compute-task"),
+            (f2.key, "compute-task", "released"),
         ],
     )

99 changes: 58 additions & 41 deletions distributed/worker.py
@@ -1854,6 +1854,10 @@ def handle_acquire_replicas(
         who_has: dict[str, Collection[str]],
         stimulus_id: str,
     ) -> None:
+        if self.validate:
+            assert set(keys) == who_has.keys()
+            assert all(who_has.values())
+
         recommendations: Recs = {}
         for key in keys:
             ts = self.ensure_task_exists(
@@ -1871,6 +1875,10 @@ def handle_acquire_replicas(
         self.update_who_has(who_has)
         self.transitions(recommendations, stimulus_id=stimulus_id)
 
+        if self.validate:
+            for key in keys:
+                assert self.tasks[key].state != "released", self.story(key)
+
     def ensure_task_exists(
         self, key: str, *, priority: tuple[int, ...], stimulus_id: str
     ) -> TaskState:
@@ -1891,19 +1899,18 @@ def handle_compute_task(
         *,
         key: str,
         who_has: dict[str, Collection[str]],
+        nbytes: dict[str, int],
         priority: tuple[int, ...],
         duration: float,
         function=None,
         args=None,
         kwargs=None,
         task=no_value,  # distributed.scheduler.TaskState.run_spec
-        nbytes: dict[str, int] | None = None,
         resource_restrictions: dict[str, float] | None = None,
         actor: bool = False,
         annotations: dict | None = None,
         stimulus_id: str,
     ) -> None:
-        self.log.append((key, "compute-task", stimulus_id, time()))
         try:
             ts = self.tasks[key]
             logger.debug(
@@ -1912,47 +1919,14 @@
             )
         except KeyError:
             self.tasks[key] = ts = TaskState(key)
-
-        ts.run_spec = SerializedTask(function, args, kwargs, task)
-
-        assert isinstance(priority, tuple)
-        priority = priority + (self.generation,)
-        self.generation -= 1
-
-        if actor:
-            self.actors[ts.key] = None
-
-        ts.exception = None
-        ts.traceback = None
-        ts.exception_text = ""
-        ts.traceback_text = ""
-        ts.priority = priority
-        ts.duration = duration
-        if resource_restrictions:
-            ts.resource_restrictions = resource_restrictions
-        ts.annotations = annotations
+        self.log.append((key, "compute-task", ts.state, stimulus_id, time()))
 
         recommendations: Recs = {}
         instructions: Instructions = []
-        for dependency in who_has:
-            dep_ts = self.ensure_task_exists(
-                key=dependency,
-                priority=priority,
-                stimulus_id=stimulus_id,
-            )
-
-            # link up to child / parents
-            ts.dependencies.add(dep_ts)
-            dep_ts.dependents.add(ts)
-
-        if nbytes is not None:
-            for key, value in nbytes.items():
-                self.tasks[key].nbytes = value
-
-        if ts.state in READY | {"executing", "waiting", "resumed"}:
+        if ts.state in READY | {"executing", "long-running", "waiting", "resumed"}:
             pass
         elif ts.state == "memory":
             recommendations[ts] = "memory"
             instructions.append(
                 self._get_task_finished_msg(ts, stimulus_id=stimulus_id)
             )
@@ -1965,12 +1939,56 @@
                 "error",
             }:
                 recommendations[ts] = "waiting"
-        else:  # pragma: no cover
+
+            ts.run_spec = SerializedTask(function, args, kwargs, task)
+
+            assert isinstance(priority, tuple)
+            priority = priority + (self.generation,)
+            self.generation -= 1
+
+            if actor:
+                self.actors[ts.key] = None
+
+            ts.exception = None
+            ts.traceback = None
+            ts.exception_text = ""
+            ts.traceback_text = ""
+            ts.priority = priority
+            ts.duration = duration
+            if resource_restrictions:
+                ts.resource_restrictions = resource_restrictions
+            ts.annotations = annotations
+
+            if self.validate:
+                assert who_has.keys() == nbytes.keys()
+                assert all(who_has.values())
+
+            for dep_key, dep_workers in who_has.items():
+                dep_ts = self.ensure_task_exists(
+                    key=dep_key,
+                    priority=priority,
+                    stimulus_id=stimulus_id,
+                )
+                # link up to child / parents
+                ts.dependencies.add(dep_ts)
+                dep_ts.dependents.add(ts)
+
+            for dep_key, value in nbytes.items():
+                self.tasks[dep_key].nbytes = value
+
+            self.update_who_has(who_has)
+        else:  # pragma: nocover
             raise RuntimeError(f"Unexpected task state encountered {ts} {stimulus_id}")
 
-        self._handle_instructions(instructions)
-        self.update_who_has(who_has)
         self.transitions(recommendations, stimulus_id=stimulus_id)
+        self._handle_instructions(instructions)
+
+        if self.validate:
+            # All previously unknown tasks that were created above by
+            # ensure_tasks_exists() have been transitioned to fetch or flight
+            assert all(
+                ts2.state != "released" for ts2 in (ts, *ts.dependencies)
+            ), self.story(ts, *ts.dependencies)
 
     ########################
     # Worker State Machine #
@@ -3429,7 +3447,6 @@ async def find_missing(self) -> None:
                     self.scheduler.who_has,
                     keys=[ts.key for ts in self._missing_dep_flight],
                 )
-                who_has = {k: v for k, v in who_has.items() if v}
                 self.update_who_has(who_has)
                 recommendations: Recs = {}
                 for ts in self._missing_dep_flight:
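
To summarize the handle_compute_task restructure in distributed/worker.py: the handler now dispatches on the task's current state first, and only (re)initializes run_spec, priority, and dependencies when the task is in a state from which it can actually be (re)computed. A compressed, illustrative sketch of that dispatch; the state names come from the diff, while the helper itself is hypothetical:

```python
READY = {"ready", "constrained"}  # assumption: simplified stand-in for worker.READY

def compute_task_branch(state: str) -> str:
    """Which branch the restructured handler takes for a given task state."""
    if state in READY | {"executing", "long-running", "waiting", "resumed"}:
        return "already progressing; duplicate request is ignored"
    elif state == "memory":
        return "re-send task-finished to the scheduler"
    elif state in {"released", "fetch", "flight", "missing", "cancelled", "error"}:
        return "initialize run_spec/priority/deps and recommend 'waiting'"
    else:
        raise RuntimeError(f"Unexpected task state encountered {state}")

assert compute_task_branch("memory") == "re-send task-finished to the scheduler"
```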
