Skip to content

Commit

Permalink
Refactor missing-data command (#6332)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored May 13, 2022
1 parent 3bedd8e commit 50d2911
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 50 deletions.
57 changes: 26 additions & 31 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4683,53 +4683,48 @@ def handle_task_erred(self, key: str, stimulus_id: str, **msg) -> None:
self.send_all(client_msgs, worker_msgs)

def handle_missing_data(
self, key: str, errant_worker: str, stimulus_id: str, **kwargs
self, key: str, worker: str, errant_worker: str, stimulus_id: str
) -> None:
"""Signal that `errant_worker` does not hold `key`
"""Signal that `errant_worker` does not hold `key`.
This may either indicate that `errant_worker` is dead or that we may be
working with stale data and need to remove `key` from the workers
`has_what`.
If no replica of a task is available anymore, the task is transitioned
back to released and rescheduled, if possible.
This may either indicate that `errant_worker` is dead or that we may be working
with stale data and need to remove `key` from the worker's `has_what`. If no
replica of a task is available anymore, the task is transitioned back to
released and rescheduled, if possible.
Parameters
----------
key : str, optional
Task key that could not be found, by default None
errant_worker : str, optional
Address of the worker supposed to hold a replica, by default None
key : str
Task key that could not be found
worker : str
Address of the worker informing the scheduler
errant_worker : str
Address of the worker supposed to hold a replica
"""
logger.debug("handle missing data key=%s worker=%s", key, errant_worker)
logger.debug(f"handle missing data {key=} {worker=} {errant_worker=}")
self.log_event(errant_worker, {"action": "missing-data", "key": key})

if key not in self.tasks:
ts = self.tasks.get(key)
ws = self.workers.get(errant_worker)
if not ts or not ws or ws not in ts.who_has:
return

ts: TaskState = self.tasks[key]
ws: WorkerState = self.workers.get(errant_worker)

if ws is not None and ws in ts.who_has:
self.remove_replica(ts, ws)
self.remove_replica(ts, ws)
if ts.state == "memory" and not ts.who_has:
if ts.run_spec:
self.transitions({key: "released"}, stimulus_id)
else:
self.transitions({key: "forgotten"}, stimulus_id)

def release_worker_data(self, key, worker, stimulus_id):
ws: WorkerState = self.workers.get(worker)
ts: TaskState = self.tasks.get(key)
if not ws or not ts:
def release_worker_data(self, key: str, worker: str, stimulus_id: str) -> None:
ts = self.tasks.get(key)
ws = self.workers.get(worker)
if not ts or not ws or ws not in ts.who_has:
return
recommendations: dict = {}
if ws in ts.who_has:
self.remove_replica(ts, ws)
if not ts.who_has:
recommendations[ts.key] = "released"
if recommendations:
self.transitions(recommendations, stimulus_id)

self.remove_replica(ts, ws)
if not ts.who_has:
self.transitions({key: "released"}, stimulus_id)

def handle_long_running(self, key=None, worker=None, compute_duration=None):
"""A task has seceded from the thread pool
Expand Down Expand Up @@ -4907,7 +4902,7 @@ async def register_scheduler_plugin(self, plugin, name=None, idempotent=None):

self.add_plugin(plugin, name=name, idempotent=idempotent)

def worker_send(self, worker, msg):
def worker_send(self, worker: str, msg: dict[str, Any]) -> None:
"""Send message to worker
This also handles connection failures by adding a callback to remove
Expand Down
6 changes: 4 additions & 2 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2696,15 +2696,17 @@ async def test_gather_dep_exception_one_task_2(c, s, a, b):
See also test_gather_dep_exception_one_task
"""
# This test does not trigger the condition reliably but is a very easy case
# which should function correctly regardles
# which should function correctly regardless

fut1 = c.submit(inc, 1, workers=[a.address], key="f1")
fut2 = c.submit(inc, fut1, workers=[b.address], key="f2")

while fut1.key not in b.tasks or b.tasks[fut1.key].state == "flight":
await asyncio.sleep(0)

s.handle_missing_data(key="f1", errant_worker=a.address, stimulus_id="test")
s.handle_missing_data(
key="f1", worker=b.address, errant_worker=a.address, stimulus_id="test"
)

await fut2

Expand Down
4 changes: 2 additions & 2 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ def test_sendmsg_to_dict():
def test_merge_recs_instructions():
x = TaskState("x")
y = TaskState("y")
instr1 = RescheduleMsg(key="foo", worker="a", stimulus_id="test")
instr2 = RescheduleMsg(key="bar", worker="b", stimulus_id="test")
instr1 = RescheduleMsg(key="foo", stimulus_id="test")
instr2 = RescheduleMsg(key="bar", stimulus_id="test")
assert merge_recs_instructions(
({x: "memory"}, [instr1]),
({y: "released"}, [instr2]),
Expand Down
24 changes: 11 additions & 13 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
Instructions,
InvalidTransition,
LongRunningMsg,
MissingDataMsg,
Recs,
RecsInstrs,
ReleaseWorkerDataMsg,
Expand Down Expand Up @@ -2145,7 +2146,7 @@ def transition_long_running_rescheduled(
self, ts: TaskState, *, stimulus_id: str
) -> RecsInstrs:
recs: Recs = {ts: "released"}
smsg = RescheduleMsg(key=ts.key, worker=self.address, stimulus_id=stimulus_id)
smsg = RescheduleMsg(key=ts.key, stimulus_id=stimulus_id)
return recs, [smsg]

def transition_executing_rescheduled(
Expand All @@ -2158,11 +2159,7 @@ def transition_executing_rescheduled(
return merge_recs_instructions(
(
{ts: "released"},
[
RescheduleMsg(
key=ts.key, worker=self.address, stimulus_id=stimulus_id
)
],
[RescheduleMsg(key=ts.key, stimulus_id=stimulus_id)],
),
self._ensure_computing(),
)
Expand Down Expand Up @@ -3285,6 +3282,7 @@ async def gather_dep(
return None

recommendations: Recs = {}
instructions: Instructions = []
response = {}
to_gather_keys: set[str] = set()
cancelled_keys: set[str] = set()
Expand Down Expand Up @@ -3406,17 +3404,17 @@ def done_event():
ts.who_has.discard(worker)
self.has_what[worker].discard(ts.key)
self.log.append((d, "missing-dep", stimulus_id, time()))
self.batched_stream.send(
{
"op": "missing-data",
"errant_worker": worker,
"key": d,
"stimulus_id": stimulus_id,
}
instructions.append(
MissingDataMsg(
key=d,
errant_worker=worker,
stimulus_id=stimulus_id,
)
)
recommendations[ts] = "fetch"
del data, response
self.transitions(recommendations, stimulus_id=stimulus_id)
self._handle_instructions(instructions)

if refresh_who_has:
# All workers that hold known replicas of our tasks are busy.
Expand Down
13 changes: 11 additions & 2 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,14 +346,23 @@ class ReleaseWorkerDataMsg(SendMessageToScheduler):
stimulus_id: str


@dataclass
class MissingDataMsg(SendMessageToScheduler):
op = "missing-data"

__slots__ = ("key", "errant_worker", "stimulus_id")
key: str
errant_worker: str
stimulus_id: str


# Not to be confused with RescheduleEvent below or the distributed.Reschedule Exception
@dataclass
class RescheduleMsg(SendMessageToScheduler):
op = "reschedule"

__slots__ = ("key", "worker", "stimulus_id")
__slots__ = ("key", "stimulus_id")
key: str
worker: str
stimulus_id: str


Expand Down

0 comments on commit 50d2911

Please sign in to comment.