diff --git a/distributed/scheduler.py b/distributed/scheduler.py index dd30424a671..e03983bde49 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -847,6 +847,14 @@ class TaskGroup: #: The result types of this TaskGroup types: set[str] + #: The worker most recently assigned a task from this group, or None when the group + #: is not identified to be root-like by `SchedulerState.decide_worker`. + last_worker: WorkerState | None + + #: If `last_worker` is not None, the number of times that worker should be assigned + #: subsequent tasks until a new worker is chosen. + last_worker_tasks_left: int + prefix: TaskPrefix | None start: float stop: float @@ -865,6 +873,8 @@ def __init__(self, name: str): self.start = 0.0 self.stop = 0.0 self.all_durations = defaultdict(float) + self.last_worker = None + self.last_worker_tasks_left = 0 def add_duration(self, action: str, start: float, stop: float) -> None: duration = stop - start @@ -1302,8 +1312,6 @@ class SchedulerState: "extensions", "host_info", "idle", - "last_root_worker", - "last_root_worker_tasks_left", "n_tasks", "queued", "resources", @@ -1375,8 +1383,6 @@ def __init__( self.total_nthreads = 0 self.total_occupancy = 0.0 self.unknown_durations: dict[str, set[TaskState]] = {} - self.last_root_worker: WorkerState | None = None - self.last_root_worker_tasks_left: int = 0 self.queued = queued self.unrunnable = unrunnable self.validate = validate @@ -1824,24 +1830,24 @@ def decide_worker_rootish_queuing_disabled( if not pool: return None - lws = self.last_root_worker + tg = ts.group + lws = tg.last_worker if not ( - lws - and self.last_root_worker_tasks_left - and self.workers.get(lws.address) is lws + lws and tg.last_worker_tasks_left and self.workers.get(lws.address) is lws ): # Last-used worker is full or unknown; pick a new worker for the next few tasks - ws = self.last_root_worker = min( - pool, key=lambda ws: len(ws.processing) / ws.nthreads - ) - # TODO better batching metric (`len(tg)` is not necessarily the total number of root tasks!) - self.last_root_worker_tasks_left = math.floor( - (len(ts.group) / self.total_nthreads) * ws.nthreads + ws = min(pool, key=partial(self.worker_objective, ts)) + tg.last_worker_tasks_left = math.floor( + (len(tg) / self.total_nthreads) * ws.nthreads ) else: ws = lws - self.last_root_worker_tasks_left -= 1 + # Record `last_worker`, or clear it on the final task + tg.last_worker = ( + ws if tg.states["released"] + tg.states["waiting"] > 1 else None + ) + tg.last_worker_tasks_left -= 1 if self.validate and ws is not None: assert self.workers.get(ws.address) is ws diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index f83e5f45bb7..f29594bf488 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -250,21 +250,6 @@ def random(**kwargs): test_decide_worker_coschedule_order_neighbors_() -@pytest.mark.parametrize("ngroups", [1, 2, 3, 5]) -@gen_cluster( - client=True, - nthreads=[("", 1), ("", 1)], -) -async def test_decide_worker_coschedule_order_binary_op(c, s, a, b, ngroups): - roots = [[delayed(i, name=f"x-{n}-{i}") for i in range(8)] for n in range(ngroups)] - zs = [sum(rs) for rs in zip(*roots)] - - await c.gather(c.compute(zs)) - - assert not a.transfer_incoming_log, [l["keys"] for l in a.transfer_incoming_log] - assert not b.transfer_incoming_log, [l["keys"] for l in b.transfer_incoming_log] - - @pytest.mark.slow @gen_cluster( nthreads=[("", 2)] * 4,