diff --git a/distributed/shuffle/_shuffle_extension.py b/distributed/shuffle/_shuffle_extension.py index b810c84d9ec..37ac35893c6 100644 --- a/distributed/shuffle/_shuffle_extension.py +++ b/distributed/shuffle/_shuffle_extension.py @@ -609,7 +609,6 @@ class ShuffleSchedulerExtension(SchedulerPlugin): participating_workers: dict[ShuffleId, set[str]] tombstones: set[ShuffleId] erred_shuffles: dict[ShuffleId, Exception] - barriers: dict[ShuffleId, str] def __init__(self, scheduler: Scheduler): self.scheduler = scheduler @@ -629,7 +628,6 @@ def __init__(self, scheduler: Scheduler): self.participating_workers = {} self.tombstones = set() self.erred_shuffles = {} - self.barriers = {} self.scheduler.add_plugin(self) def shuffle_ids(self) -> set[ShuffleId]: @@ -646,7 +644,7 @@ def barrier_key(cls, shuffle_id: ShuffleId) -> str: @classmethod def id_from_key(cls, key: str) -> ShuffleId: - assert "shuffle-barrier-" in key + assert key.startswith("shuffle-barrier-") return ShuffleId(key.replace("shuffle-barrier-", "")) def get( @@ -674,7 +672,6 @@ def get( output_workers = set() name = self.barrier_key(id) - self.barriers[id] = name mapping = {} for ts in self.scheduler.tasks[name].dependents: @@ -724,7 +721,7 @@ async def remove_worker(self, scheduler: Scheduler, worker: str) -> None: contact_workers = shuffle_workers.copy() contact_workers.discard(worker) affected_shuffles.add(shuffle_id) - name = self.barriers[shuffle_id] + name = self.barrier_key(shuffle_id) barrier_task = self.scheduler.tasks.get(name) if barrier_task: barriers.append(barrier_task) @@ -769,11 +766,12 @@ def transition( ) -> None: if finish != "forgotten": return - if key not in self.barriers.values(): - + if not key.startswith("shuffle-barrier-"): + return + shuffle_id = self.id_from_key(key) + if shuffle_id not in self.worker_for: return - shuffle_id = ShuffleSchedulerExtension.id_from_key(key) participating_workers = self.participating_workers[shuffle_id] worker_msgs = { worker: [ @@ -806,7 +804,6 @@ def _clean_on_scheduler(self, id: ShuffleId) -> None: del self.completed_workers[id] del self.participating_workers[id] self.erred_shuffles.pop(id, None) - del self.barriers[id] with contextlib.suppress(KeyError): del self.heartbeats[id]