diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 00396165fe..52d1c254ee 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4683,53 +4683,48 @@ def handle_task_erred(self, key: str, stimulus_id: str, **msg) -> None: self.send_all(client_msgs, worker_msgs) def handle_missing_data( - self, key: str, errant_worker: str, stimulus_id: str, **kwargs + self, key: str, worker: str, errant_worker: str, stimulus_id: str ) -> None: - """Signal that `errant_worker` does not hold `key` + """Signal that `errant_worker` does not hold `key`. - This may either indicate that `errant_worker` is dead or that we may be - working with stale data and need to remove `key` from the workers - `has_what`. - - If no replica of a task is available anymore, the task is transitioned - back to released and rescheduled, if possible. + This may either indicate that `errant_worker` is dead or that we may be working + with stale data and need to remove `key` from the workers `has_what`. If no + replica of a task is available anymore, the task is transitioned back to + released and rescheduled, if possible. Parameters ---------- - key : str, optional - Task key that could not be found, by default None - errant_worker : str, optional - Address of the worker supposed to hold a replica, by default None + key : str + Task key that could not be found + worker : str + Address of the worker informing the scheduler + errant_worker : str + Address of the worker supposed to hold a replica """ - logger.debug("handle missing data key=%s worker=%s", key, errant_worker) + logger.debug(f"handle missing data {key=} {worker=} {errant_worker=}") self.log_event(errant_worker, {"action": "missing-data", "key": key}) - if key not in self.tasks: + ts = self.tasks.get(key) + ws = self.workers.get(errant_worker) + if not ts or not ws or ws not in ts.who_has: return - ts: TaskState = self.tasks[key] - ws: WorkerState = self.workers.get(errant_worker) - - if ws is not None and ws in ts.who_has: - self.remove_replica(ts, ws) + self.remove_replica(ts, ws) if ts.state == "memory" and not ts.who_has: if ts.run_spec: self.transitions({key: "released"}, stimulus_id) else: self.transitions({key: "forgotten"}, stimulus_id) - def release_worker_data(self, key, worker, stimulus_id): - ws: WorkerState = self.workers.get(worker) - ts: TaskState = self.tasks.get(key) - if not ws or not ts: + def release_worker_data(self, key: str, worker: str, stimulus_id: str) -> None: + ts = self.tasks.get(key) + ws = self.workers.get(worker) + if not ts or not ws or ws not in ts.who_has: return - recommendations: dict = {} - if ws in ts.who_has: - self.remove_replica(ts, ws) - if not ts.who_has: - recommendations[ts.key] = "released" - if recommendations: - self.transitions(recommendations, stimulus_id) + + self.remove_replica(ts, ws) + if not ts.who_has: + self.transitions({key: "released"}, stimulus_id) def handle_long_running(self, key=None, worker=None, compute_duration=None): """A task has seceded from the thread pool @@ -4907,7 +4902,7 @@ async def register_scheduler_plugin(self, plugin, name=None, idempotent=None): self.add_plugin(plugin, name=name, idempotent=idempotent) - def worker_send(self, worker, msg): + def worker_send(self, worker: str, msg: dict[str, Any]) -> None: """Send message to worker This also handles connection failures by adding a callback to remove diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 2e5c0b5a37..bb99438e9f 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -2696,7 +2696,7 @@ async def test_gather_dep_exception_one_task_2(c, s, a, b): See also test_gather_dep_exception_one_task """ # This test does not trigger the condition reliably but is a very easy case - # which should function correctly regardles + # which should function correctly regardless fut1 = c.submit(inc, 1, workers=[a.address], key="f1") fut2 = c.submit(inc, fut1, workers=[b.address], key="f2") @@ -2704,7 +2704,9 @@ async def test_gather_dep_exception_one_task_2(c, s, a, b): while fut1.key not in b.tasks or b.tasks[fut1.key].state == "flight": await asyncio.sleep(0) - s.handle_missing_data(key="f1", errant_worker=a.address, stimulus_id="test") + s.handle_missing_data( + key="f1", worker=b.address, errant_worker=a.address, stimulus_id="test" + ) await fut2 diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 4e65afe643..9295b9987d 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -133,8 +133,8 @@ def test_sendmsg_to_dict(): def test_merge_recs_instructions(): x = TaskState("x") y = TaskState("y") - instr1 = RescheduleMsg(key="foo", worker="a", stimulus_id="test") - instr2 = RescheduleMsg(key="bar", worker="b", stimulus_id="test") + instr1 = RescheduleMsg(key="foo", stimulus_id="test") + instr2 = RescheduleMsg(key="bar", stimulus_id="test") assert merge_recs_instructions( ({x: "memory"}, [instr1]), ({y: "released"}, [instr2]), diff --git a/distributed/worker.py b/distributed/worker.py index bec71904a7..9dc2e8e61e 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -120,6 +120,7 @@ Instructions, InvalidTransition, LongRunningMsg, + MissingDataMsg, Recs, RecsInstrs, ReleaseWorkerDataMsg, @@ -2145,7 +2146,7 @@ def transition_long_running_rescheduled( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: recs: Recs = {ts: "released"} - smsg = RescheduleMsg(key=ts.key, worker=self.address, stimulus_id=stimulus_id) + smsg = RescheduleMsg(key=ts.key, stimulus_id=stimulus_id) return recs, [smsg] def transition_executing_rescheduled( @@ -2158,11 +2159,7 @@ def transition_executing_rescheduled( return merge_recs_instructions( ( {ts: "released"}, - [ - RescheduleMsg( - key=ts.key, worker=self.address, stimulus_id=stimulus_id - ) - ], + [RescheduleMsg(key=ts.key, stimulus_id=stimulus_id)], ), self._ensure_computing(), ) @@ -3285,6 +3282,7 @@ async def gather_dep( return None recommendations: Recs = {} + instructions: Instructions = [] response = {} to_gather_keys: set[str] = set() cancelled_keys: set[str] = set() @@ -3406,17 +3404,17 @@ def done_event(): ts.who_has.discard(worker) self.has_what[worker].discard(ts.key) self.log.append((d, "missing-dep", stimulus_id, time())) - self.batched_stream.send( - { - "op": "missing-data", - "errant_worker": worker, - "key": d, - "stimulus_id": stimulus_id, - } + instructions.append( + MissingDataMsg( + key=d, + errant_worker=worker, + stimulus_id=stimulus_id, + ) ) recommendations[ts] = "fetch" del data, response self.transitions(recommendations, stimulus_id=stimulus_id) + self._handle_instructions(instructions) if refresh_who_has: # All workers that hold known replicas of our tasks are busy. diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index f78b34ea42..8fd4b17b03 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -346,14 +346,23 @@ class ReleaseWorkerDataMsg(SendMessageToScheduler): stimulus_id: str +@dataclass +class MissingDataMsg(SendMessageToScheduler): + op = "missing-data" + + __slots__ = ("key", "errant_worker", "stimulus_id") + key: str + errant_worker: str + stimulus_id: str + + # Not to be confused with RescheduleEvent below or the distributed.Reschedule Exception @dataclass class RescheduleMsg(SendMessageToScheduler): op = "reschedule" - __slots__ = ("key", "worker", "stimulus_id") + __slots__ = ("key", "stimulus_id") key: str - worker: str stimulus_id: str