diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml index a110aaa8078..1e6daf54253 100644 --- a/distributed/distributed-schema.yaml +++ b/distributed/distributed-schema.yaml @@ -508,13 +508,21 @@ properties: donate data and only nodes below 45% will receive them. This helps avoid data from bouncing around the cluster repeatedly. + transfer: + oneOf: + - {type: number, minimum: 0, maximum: 1} + - {enum: [false]} + description: >- + When the total size of incoming data transfers gets above this amount, + we start throttling incoming data transfers + target: oneOf: - {type: number, minimum: 0, maximum: 1} - {enum: [false]} description: >- When the process memory (as observed by the operating system) gets - above this amount we start spilling the dask keys holding the largest + above this amount, we start spilling the dask keys holding the oldest chunks of data to disk spill: @@ -523,7 +531,9 @@ properties: - {enum: [false]} description: >- When the process memory (as observed by the operating system) gets - above this amount we spill all data to disk. + above this amount, we spill data to disk, starting from the dask keys + holding the oldest chunks of data, until the process memory falls below + the target threshold. pause: oneOf: @@ -531,7 +541,8 @@ properties: - {enum: [false]} description: >- When the process memory (as observed by the operating system) gets - above this amount we no longer start new tasks on this worker. + above this amount, we no longer start new tasks or fetch new + data on the worker. terminate: oneOf: diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index ca8292146f4..d3a93a7a651 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -143,6 +143,9 @@ distributed: # Fractions of worker process memory at which we take action to avoid memory # blowup. Set any of the values to False to turn off the behavior entirely. + # All fractions are relative to each worker's memory_limit. 
+ transfer: 0.10 # fractional size of incoming data transfers where we start + # throttling incoming data transfers target: 0.60 # fraction of managed memory where we start spilling to disk spill: 0.70 # fraction of process memory where we start spilling to disk pause: 0.80 # fraction of process memory at which we pause worker threads diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index f2c8372c7f8..2f1861252c4 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -3,6 +3,7 @@ import asyncio import gc import pickle +from collections import defaultdict from collections.abc import Iterator import pytest @@ -1351,3 +1352,189 @@ def test_transfer_incoming_metrics(ws): assert ws.transfer_incoming_bytes == 0 assert ws.transfer_incoming_count == 0 assert ws.transfer_incoming_count_total == 4 + + +def test_throttling_does_not_affect_first_transfer(ws): + ws.transfer_incoming_count_limit = 100 + ws.transfer_incoming_bytes_limit = 100 + ws.transfer_incoming_bytes_throttle_threshold = 1 + ws2 = "127.0.0.1:2" + ws.handle_stimulus( + ComputeTaskEvent.dummy( + "c", + who_has={"a": [ws2]}, + nbytes={"a": 200}, + stimulus_id="s1", + ) + ) + assert ws.tasks["a"].state == "flight" + + +def test_throttle_incoming_transfers_on_count_limit(ws): + ws.transfer_incoming_count_limit = 1 + ws.transfer_incoming_bytes_limit = 100_000 + ws.transfer_incoming_bytes_throttle_threshold = 1 + ws2 = "127.0.0.1:2" + ws3 = "127.0.0.1:3" + who_has = {"a": [ws2], "b": [ws3]} + ws.handle_stimulus( + ComputeTaskEvent.dummy( + "c", + who_has=who_has, + nbytes={"a": 100, "b": 100}, + stimulus_id="s1", + ) + ) + tasks_by_state = defaultdict(list) + for ts in ws.tasks.values(): + tasks_by_state[ts.state].append(ts) + assert len(tasks_by_state["flight"]) == 1 + assert len(tasks_by_state["fetch"]) == 1 + assert ws.transfer_incoming_bytes == 100 + + in_flight_task = tasks_by_state["flight"][0] + ws.handle_stimulus( + GatherDepSuccessEvent( + worker=who_has[in_flight_task.key][0], + data={in_flight_task.key: 123}, + total_nbytes=100, + stimulus_id="s2", + ) + ) + assert tasks_by_state["flight"][0].state == "memory" + assert tasks_by_state["fetch"][0].state == "flight" + assert ws.transfer_incoming_bytes == 100 + + +def test_throttling_incoming_transfer_on_transfer_bytes_same_worker(ws): + ws.transfer_incoming_count_limit = 100 + ws.transfer_incoming_bytes_limit = 250 + ws.transfer_incoming_bytes_throttle_threshold = 1 + ws2 = "127.0.0.1:2" + ws.handle_stimulus( + ComputeTaskEvent.dummy( + "d", + who_has={"a": [ws2], "b": [ws2], "c": [ws2]}, + nbytes={"a": 100, "b": 100, "c": 100}, + stimulus_id="s1", + ) + ) + tasks_by_state = defaultdict(list) + for ts in ws.tasks.values(): + tasks_by_state[ts.state].append(ts) + assert ws.transfer_incoming_bytes == 200 + assert len(tasks_by_state["flight"]) == 2 + assert len(tasks_by_state["fetch"]) == 1 + + ws.handle_stimulus( + GatherDepSuccessEvent( + worker=ws2, + data={ts.key: 123 for ts in tasks_by_state["flight"]}, + total_nbytes=200, + stimulus_id="s2", + ) + ) + assert all(ts.state == "memory" for ts in tasks_by_state["flight"]) + assert all(ts.state == "flight" for ts in tasks_by_state["fetch"]) + + +def test_throttling_incoming_transfer_on_transfer_bytes_different_workers(ws): + ws.transfer_incoming_count_limit = 100 + ws.transfer_incoming_bytes_limit = 150 + ws.transfer_incoming_bytes_throttle_threshold = 1 + ws2 = "127.0.0.1:2" + ws3 = "127.0.0.1:3" + who_has = {"a": 
[ws2], "b": [ws3]} + ws.handle_stimulus( + ComputeTaskEvent.dummy( + "c", + who_has=who_has, + nbytes={"a": 100, "b": 100}, + stimulus_id="s1", + ) + ) + tasks_by_state = defaultdict(list) + for ts in ws.tasks.values(): + tasks_by_state[ts.state].append(ts) + assert ws.transfer_incoming_bytes == 100 + assert len(tasks_by_state["flight"]) == 1 + assert len(tasks_by_state["fetch"]) == 1 + + in_flight_task = tasks_by_state["flight"][0] + ws.handle_stimulus( + GatherDepSuccessEvent( + worker=who_has[in_flight_task.key][0], + data={in_flight_task.key: 123}, + total_nbytes=100, + stimulus_id="s2", + ) + ) + assert tasks_by_state["flight"][0].state == "memory" + assert tasks_by_state["fetch"][0].state == "flight" + + +def test_do_not_throttle_connections_while_below_threshold(ws): + ws.transfer_incoming_count_limit = 1 + ws.transfer_incoming_bytes_limit = 200 + ws.transfer_incoming_bytes_throttle_threshold = 50 + ws2 = "127.0.0.1:2" + ws3 = "127.0.0.1:3" + ws4 = "127.0.0.1:4" + ws.handle_stimulus( + ComputeTaskEvent.dummy( + "b", + who_has={"a": [ws2]}, + nbytes={"a": 1}, + stimulus_id="s1", + ) + ) + assert ws.tasks["a"].state == "flight" + + ws.handle_stimulus( + ComputeTaskEvent.dummy( + "d", + who_has={"c": [ws3]}, + nbytes={"c": 1}, + stimulus_id="s2", + ) + ) + assert ws.tasks["c"].state == "flight" + + ws.handle_stimulus( + ComputeTaskEvent.dummy( + "f", + who_has={"e": [ws4]}, + nbytes={"e": 100}, + stimulus_id="s3", + ) + ) + assert ws.tasks["e"].state == "flight" + assert ws.transfer_incoming_bytes == 102 + + +def test_throttle_on_transfer_bytes_regardless_of_threshold(ws): + ws.transfer_incoming_count_limit = 1 + ws.transfer_incoming_bytes_limit = 100 + ws.transfer_incoming_bytes_throttle_threshold = 50 + ws2 = "127.0.0.1:2" + ws3 = "127.0.0.1:3" + ws.handle_stimulus( + ComputeTaskEvent.dummy( + "b", + who_has={"a": [ws2]}, + nbytes={"a": 1}, + stimulus_id="s1", + ) + ) + assert ws.tasks["a"].state == "flight" + + ws.handle_stimulus( + ComputeTaskEvent.dummy( + "d", + who_has={"c": [ws3]}, + nbytes={"c": 100}, + stimulus_id="s2", + ) + ) + assert ws.tasks["c"].state == "fetch" + assert ws.transfer_incoming_bytes == 1 diff --git a/distributed/worker.py b/distributed/worker.py index 3d4b98da019..9764de061bf 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -745,6 +745,18 @@ def __init__( memory_spill_fraction=memory_spill_fraction, memory_pause_fraction=memory_pause_fraction, ) + + transfer_incoming_bytes_limit = None + transfer_incoming_bytes_fraction = dask.config.get( + "distributed.worker.memory.transfer" + ) + if ( + self.memory_manager.memory_limit is not None + and transfer_incoming_bytes_fraction is not False + ): + transfer_incoming_bytes_limit = int( + self.memory_manager.memory_limit * transfer_incoming_bytes_fraction + ) state = WorkerState( nthreads=nthreads, data=self.memory_manager.data, @@ -754,6 +766,7 @@ def __init__( transfer_incoming_count_limit=transfer_incoming_count_limit, validate=validate, transition_counter_max=transition_counter_max, + transfer_incoming_bytes_limit=transfer_incoming_bytes_limit, ) BaseWorker.__init__(self, state) diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index fec249c5abb..8080c03b626 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -1232,6 +1232,9 @@ class WorkerState: #: In production, it should always be set to False. 
transition_counter_max: int | Literal[False] + #: Limit of bytes for incoming data transfers; this is used for throttling. + transfer_incoming_bytes_limit: int | None + #: Statically-seeded random state, used to guarantee determinism whenever a #: pseudo-random choice is required rng: random.Random @@ -1250,6 +1253,7 @@ def __init__( transfer_incoming_count_limit: int = 9999, validate: bool = True, transition_counter_max: int | Literal[False] = False, + transfer_incoming_bytes_limit: int | None = None, ): self.nthreads = nthreads @@ -1293,6 +1297,7 @@ def __init__( self.stimulus_log = deque(maxlen=10_000) self.transition_counter = 0 self.transition_counter_max = transition_counter_max + self.transfer_incoming_bytes_limit = transfer_incoming_bytes_limit self.actors = {} self.rng = random.Random(0) @@ -1468,17 +1473,36 @@ def _purge_state(self, ts: TaskState) -> None: self.long_running.discard(ts) self.in_flight_tasks.discard(ts) + def _should_throttle_incoming_transfers(self) -> bool: + """Decides whether the WorkerState should throttle data transfers from other workers. + + Returns + ------- + * True if the number of incoming data transfers reached its limit + and the size of incoming data transfers reached the minimum threshold for throttling + * True if the size of incoming data transfers reached its limit + * False otherwise + """ + reached_count_limit = ( + self.transfer_incoming_count >= self.transfer_incoming_count_limit + ) + reached_throttle_threshold = ( + self.transfer_incoming_bytes + >= self.transfer_incoming_bytes_throttle_threshold + ) + reached_bytes_limit = ( + self.transfer_incoming_bytes_limit is not None + and self.transfer_incoming_bytes >= self.transfer_incoming_bytes_limit + ) + return reached_count_limit and reached_throttle_threshold or reached_bytes_limit + def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: """Transition tasks from fetch to flight, until there are no more tasks in fetch state or a threshold has been reached. """ if not self.running or not self.data_needed: return {}, [] - if ( - self.transfer_incoming_count >= self.transfer_incoming_count_limit - and self.transfer_incoming_bytes - >= self.transfer_incoming_bytes_throttle_threshold - ): + if self._should_throttle_incoming_transfers(): return {}, [] recommendations: Recs = {} @@ -1489,7 +1513,12 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: to_gather_tasks, total_nbytes = self._select_keys_for_gather( available_tasks ) - assert to_gather_tasks + # We always load at least one task + assert to_gather_tasks or self.transfer_incoming_bytes + # ...but that task might be selected in the previous iteration of the loop + if not to_gather_tasks: + break + to_gather_keys = {ts.key for ts in to_gather_tasks} logger.debug( @@ -1531,11 +1560,7 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: self.in_flight_workers[worker] = to_gather_keys self.transfer_incoming_count_total += 1 self.transfer_incoming_bytes += total_nbytes - if ( - self.transfer_incoming_count >= self.transfer_incoming_count_limit - and self.transfer_incoming_bytes - >= self.transfer_incoming_bytes_throttle_threshold - ): + if self._should_throttle_incoming_transfers(): break return recommendations, instructions @@ -1610,18 +1635,28 @@ def _select_keys_for_gather( """Helper of _ensure_communicating. Fetch all tasks that are replicated on the target worker within a single - message, up to transfer_message_target_bytes. 
+ message, up to transfer_message_target_bytes or until we reach the limit + for the size of incoming data transfers. """ to_gather: list[TaskState] = [] total_nbytes = 0 + if self.transfer_incoming_bytes_limit is not None: + bytes_left_to_fetch = min( + self.transfer_incoming_bytes_limit - self.transfer_incoming_bytes, + self.transfer_message_target_bytes, + ) + else: + bytes_left_to_fetch = self.transfer_message_target_bytes + while available: ts = available.peek() - # The top-priority task is fetched regardless of its size if ( - to_gather - and total_nbytes + ts.get_nbytes() > self.transfer_message_target_bytes - ): + # When there is no other traffic, the top-priority task is fetched + # regardless of its size to ensure progress + self.transfer_incoming_bytes + or to_gather + ) and total_nbytes + ts.get_nbytes() > bytes_left_to_fetch: break for worker in ts.who_has: # This also effectively pops from available
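

The sketch below is a minimal, self-contained illustration of the two decisions this patch introduces: deriving transfer_incoming_bytes_limit from memory_limit and the new distributed.worker.memory.transfer fraction (as done in Worker.__init__), and the throttling predicate of WorkerState._should_throttle_incoming_transfers. The free functions and the __main__ example are illustrative only and are not part of the patch; in the actual code this logic lives on Worker and WorkerState.

from __future__ import annotations


def derive_transfer_incoming_bytes_limit(
    memory_limit: int | None, transfer_fraction: float | bool
) -> int | None:
    # Mirrors Worker.__init__: there is no byte limit if the worker has no
    # memory_limit or if distributed.worker.memory.transfer is set to False.
    if memory_limit is None or transfer_fraction is False:
        return None
    return int(memory_limit * transfer_fraction)


def should_throttle_incoming_transfers(
    transfer_incoming_count: int,
    transfer_incoming_count_limit: int,
    transfer_incoming_bytes: int,
    transfer_incoming_bytes_throttle_threshold: int,
    transfer_incoming_bytes_limit: int | None,
) -> bool:
    # Mirrors WorkerState._should_throttle_incoming_transfers: the count limit
    # only applies once the throttle threshold is reached, while the byte
    # limit applies unconditionally.
    reached_count_limit = transfer_incoming_count >= transfer_incoming_count_limit
    reached_throttle_threshold = (
        transfer_incoming_bytes >= transfer_incoming_bytes_throttle_threshold
    )
    reached_bytes_limit = (
        transfer_incoming_bytes_limit is not None
        and transfer_incoming_bytes >= transfer_incoming_bytes_limit
    )
    return (reached_count_limit and reached_throttle_threshold) or reached_bytes_limit


if __name__ == "__main__":
    # With a 1 GiB memory_limit and the default fraction of 0.10, throttling
    # by size kicks in once roughly 107 MB of transfers are already in flight.
    limit = derive_transfer_incoming_bytes_limit(2**30, 0.10)
    assert limit == int(2**30 * 0.10)
    assert should_throttle_incoming_transfers(
        transfer_incoming_count=3,
        transfer_incoming_count_limit=10,
        transfer_incoming_bytes=limit,
        transfer_incoming_bytes_throttle_threshold=10_000_000,
        transfer_incoming_bytes_limit=limit,
    )
    assert not should_throttle_incoming_transfers(
        transfer_incoming_count=3,
        transfer_incoming_count_limit=10,
        transfer_incoming_bytes=1_000,
        transfer_incoming_bytes_throttle_threshold=10_000_000,
        transfer_incoming_bytes_limit=limit,
    )

The fraction can be tuned, or disabled entirely, through dask's configuration, for example dask.config.set({"distributed.worker.memory.transfer": 0.10}) or the matching entry under distributed.worker.memory in distributed.yaml.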