diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py
index eaec94725c7..9b46a14b6e7 100644
--- a/distributed/tests/test_worker_state_machine.py
+++ b/distributed/tests/test_worker_state_machine.py
@@ -1010,30 +1010,40 @@ async def test_deprecated_worker_attributes(s, a, b):
     assert a.data_needed == set()
 
 
+@pytest.mark.parametrize("n_remote_workers", [1, 2])
 @pytest.mark.parametrize(
-    "nbytes,n_in_flight",
+    "nbytes,n_in_flight_per_worker",
     [
         (int(10e6), 3),
         (int(20e6), 2),
         (int(30e6), 1),
+        (int(60e6), 1),
     ],
 )
-def test_aggregate_gather_deps(ws, nbytes, n_in_flight):
+def test_aggregate_gather_deps(ws, nbytes, n_in_flight_per_worker, n_remote_workers):
     ws.transfer_message_bytes_limit = int(50e6)
-    ws2 = "127.0.0.1:2"
+    wss = [f"127.0.0.1:{2 + i}" for i in range(n_remote_workers)]
+    who_has = {f"x{i}": [wss[i // 3]] for i in range(3 * n_remote_workers)}
     instructions = ws.handle_stimulus(
         AcquireReplicasEvent(
-            who_has={"x1": [ws2], "x2": [ws2], "x3": [ws2]},
-            nbytes={"x1": nbytes, "x2": nbytes, "x3": nbytes},
+            who_has=who_has,
+            nbytes={task: nbytes for task in who_has.keys()},
             stimulus_id="s1",
         )
     )
-    assert instructions == [GatherDep.match(worker=ws2, stimulus_id="s1")]
-    assert len(instructions[0].to_gather) == n_in_flight
-    assert len(ws.in_flight_tasks) == n_in_flight
-    assert ws.transfer_incoming_bytes == nbytes * n_in_flight
-    assert ws.transfer_incoming_count == 1
-    assert ws.transfer_incoming_count_total == 1
+    assert instructions == [
+        GatherDep.match(worker=remote, stimulus_id="s1") for remote in wss
+    ]
+    assert all(
+        len(instruction.to_gather) == n_in_flight_per_worker
+        for instruction in instructions
+    )
+    assert len(ws.in_flight_tasks) == n_in_flight_per_worker * n_remote_workers
+    assert (
+        ws.transfer_incoming_bytes == nbytes * n_in_flight_per_worker * n_remote_workers
+    )
+    assert ws.transfer_incoming_count == n_remote_workers
+    assert ws.transfer_incoming_count_total == n_remote_workers
 
 
 def test_gather_priority(ws):
@@ -1358,6 +1368,7 @@ def test_throttling_does_not_affect_first_transfer(ws):
     ws.transfer_incoming_count_limit = 100
     ws.transfer_incoming_bytes_limit = 100
     ws.transfer_incoming_bytes_throttle_threshold = 1
+    ws.transfer_message_bytes_limit = 100
     ws2 = "127.0.0.1:2"
     ws.handle_stimulus(
         ComputeTaskEvent.dummy(
@@ -1370,6 +1381,25 @@
     assert ws.tasks["a"].state == "flight"
 
 
+def test_message_target_does_not_affect_first_transfer_on_different_worker(ws):
+    ws.transfer_incoming_count_limit = 100
+    ws.transfer_incoming_bytes_limit = 600
+    ws.transfer_message_bytes_limit = 100
+    ws.transfer_incoming_bytes_throttle_threshold = 1
+    ws2 = "127.0.0.1:2"
+    ws3 = "127.0.0.1:3"
+    ws.handle_stimulus(
+        ComputeTaskEvent.dummy(
+            "c",
+            who_has={"a": [ws2], "b": [ws3]},
+            nbytes={"a": 200, "b": 200},
+            stimulus_id="s1",
+        )
+    )
+    assert ws.tasks["a"].state == "flight"
+    assert ws.tasks["b"].state == "flight"
+
+
 def test_throttle_incoming_transfers_on_count_limit(ws):
     ws.transfer_incoming_count_limit = 1
     ws.transfer_incoming_bytes_limit = 100_000
diff --git a/distributed/worker.py b/distributed/worker.py
index 29ef7939f6f..3c3a16b5b66 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -6,6 +6,7 @@
 import errno
 import functools
 import logging
+import math
 import os
 import pathlib
 import random
@@ -748,7 +749,7 @@ def __init__(
             memory_pause_fraction=memory_pause_fraction,
         )
 
-        transfer_incoming_bytes_limit = None
+        transfer_incoming_bytes_limit = math.inf
         transfer_incoming_bytes_fraction = dask.config.get(
             "distributed.worker.memory.transfer"
         )
diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py
index 30b39f8d051..6add2f72d94 100644
--- a/distributed/worker_state_machine.py
+++ b/distributed/worker_state_machine.py
@@ -1234,7 +1234,7 @@ class WorkerState:
     transition_counter_max: int | Literal[False]
 
     #: Limit of bytes for incoming data transfers; this is used for throttling.
-    transfer_incoming_bytes_limit: int | None
+    transfer_incoming_bytes_limit: float
 
     #: Statically-seeded random state, used to guarantee determinism whenever a
     #: pseudo-random choice is required
@@ -1254,7 +1254,7 @@ def __init__(
         transfer_incoming_count_limit: int = 9999,
         validate: bool = True,
         transition_counter_max: int | Literal[False] = False,
-        transfer_incoming_bytes_limit: int | None = None,
+        transfer_incoming_bytes_limit: float = math.inf,
         transfer_message_bytes_limit: float = math.inf,
     ):
         self.nthreads = nthreads
@@ -1493,8 +1493,7 @@ def _should_throttle_incoming_transfers(self) -> bool:
             >= self.transfer_incoming_bytes_throttle_threshold
        )
         reached_bytes_limit = (
-            self.transfer_incoming_bytes_limit is not None
-            and self.transfer_incoming_bytes >= self.transfer_incoming_bytes_limit
+            self.transfer_incoming_bytes >= self.transfer_incoming_bytes_limit
         )
         return reached_count_limit and reached_throttle_threshold or reached_bytes_limit
 
@@ -1512,7 +1511,7 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs:
         for worker, available_tasks in self._select_workers_for_gather():
             assert worker != self.address
 
-            to_gather_tasks, total_nbytes = self._select_keys_for_gather(
+            to_gather_tasks, message_nbytes = self._select_keys_for_gather(
                 available_tasks
             )
             # We always load at least one task
@@ -1554,14 +1553,14 @@
                 GatherDep(
                     worker=worker,
                     to_gather=to_gather_keys,
-                    total_nbytes=total_nbytes,
+                    total_nbytes=message_nbytes,
                     stimulus_id=stimulus_id,
                 )
             )
 
             self.in_flight_workers[worker] = to_gather_keys
             self.transfer_incoming_count_total += 1
-            self.transfer_incoming_bytes += total_nbytes
+            self.transfer_incoming_bytes += message_nbytes
             if self._should_throttle_incoming_transfers():
                 break
 
@@ -1641,32 +1640,57 @@ def _select_keys_for_gather(
         for the size of incoming data transfers.
         """
         to_gather: list[TaskState] = []
-        total_nbytes = 0
-
-        if self.transfer_incoming_bytes_limit is not None:
-            bytes_left_to_fetch = min(
-                self.transfer_incoming_bytes_limit - self.transfer_incoming_bytes,
-                self.transfer_message_bytes_limit,
-            )
-        else:
-            bytes_left_to_fetch = self.transfer_message_bytes_limit
+        message_nbytes = 0
 
         while available:
             ts = available.peek()
-            if (
-                # When there is no other traffic, the top-priority task is fetched
-                # regardless of its size to ensure progress
-                self.transfer_incoming_bytes
-                or to_gather
-            ) and total_nbytes + ts.get_nbytes() > bytes_left_to_fetch:
+            if self._task_exceeds_transfer_limits(ts, message_nbytes):
                 break
             for worker in ts.who_has:
                 # This also effectively pops from available
                 self.data_needed[worker].remove(ts)
             to_gather.append(ts)
-            total_nbytes += ts.get_nbytes()
+            message_nbytes += ts.get_nbytes()
+
+        return to_gather, message_nbytes
+
+    def _task_exceeds_transfer_limits(self, ts: TaskState, message_nbytes: int) -> bool:
+        """Would asking to gather this task exceed transfer limits?
+
+        Parameters
+        ----------
+        ts
+            Candidate task for gathering
+        message_nbytes
+            Total number of bytes already scheduled for gathering in this message
+        Returns
+        -------
+        exceeds_limit
+            True if gathering the task would exceed limits, False otherwise
+            (in which case the task can be gathered).
+        """
+        if self.transfer_incoming_bytes == 0 and message_nbytes == 0:
+            # When there is no other traffic, the top-priority task is fetched
+            # regardless of its size to ensure progress
+            return False
+
+        incoming_bytes_allowance = (
+            self.transfer_incoming_bytes_limit - self.transfer_incoming_bytes
+        )
+
+        # If message_nbytes == 0, i.e., this is the first task to gather in this
+        # message, ignore `self.transfer_message_bytes_limit` for the top-priority
+        # task to ensure progress. Otherwise:
+        if message_nbytes != 0:
+            incoming_bytes_allowance = (
+                min(
+                    incoming_bytes_allowance,
+                    self.transfer_message_bytes_limit,
+                )
+                - message_nbytes
+            )
 
-        return to_gather, total_nbytes
+        return ts.get_nbytes() > incoming_bytes_allowance
 
     def _ensure_computing(self) -> RecsInstrs:
         if not self.running: