diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml
index 1e6daf54253..a110aaa8078 100644
--- a/distributed/distributed-schema.yaml
+++ b/distributed/distributed-schema.yaml
@@ -508,21 +508,13 @@ properties:
                       donate data and only nodes below 45% will receive them. This helps
                       avoid data from bouncing around the cluster repeatedly.

-              transfer:
-                oneOf:
-                  - {type: number, minimum: 0, maximum: 1}
-                  - {enum: [false]}
-                description: >-
-                  When the total size of incoming data transfers gets above this amount,
-                  we start throttling incoming data transfers
-
               target:
                 oneOf:
                   - {type: number, minimum: 0, maximum: 1}
                   - {enum: [false]}
                 description: >-
                   When the process memory (as observed by the operating system) gets
                   above this amount, we start spilling the dask keys holding the oldest
                   chunks of data to disk

               spill:
@@ -531,9 +523,7 @@
                 - {enum: [false]}
                 description: >-
                   When the process memory (as observed by the operating system) gets
-                  above this amount, we spill data to disk, starting from the dask keys
-                  holding the oldest chunks of data, until the process memory falls below
-                  the target threshold.
+                  above this amount, we spill the oldest chunks of data to disk until the process memory falls below the target threshold.

               pause:
                 oneOf:
@@ -541,8 +531,7 @@
                 - {enum: [false]}
                 description: >-
                   When the process memory (as observed by the operating system) gets
-                  above this amount, we no longer start new tasks or fetch new
-                  data on the worker.
+                  above this amount, we no longer start new tasks or fetch new data on the worker.

               terminate:
                 oneOf:
diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml
index d3a93a7a651..ca8292146f4 100644
--- a/distributed/distributed.yaml
+++ b/distributed/distributed.yaml
@@ -143,9 +143,6 @@ distributed:

       # Fractions of worker process memory at which we take action to avoid memory
       # blowup. Set any of the values to False to turn off the behavior entirely.
-      # All fractions are relative to each worker's memory_limit.
- transfer: 0.10 # fractional size of incoming data transfers where we start - # throttling incoming data transfers target: 0.60 # fraction of managed memory where we start spilling to disk spill: 0.70 # fraction of process memory where we start spilling to disk pause: 0.80 # fraction of process memory at which we pause worker threads diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 2f1861252c4..f2c8372c7f8 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -3,7 +3,6 @@ import asyncio import gc import pickle -from collections import defaultdict from collections.abc import Iterator import pytest @@ -1352,189 +1351,3 @@ def test_transfer_incoming_metrics(ws): assert ws.transfer_incoming_bytes == 0 assert ws.transfer_incoming_count == 0 assert ws.transfer_incoming_count_total == 4 - - -def test_throttling_does_not_affect_first_transfer(ws): - ws.transfer_incoming_count_limit = 100 - ws.transfer_incoming_bytes_limit = 100 - ws.transfer_incoming_bytes_throttle_threshold = 1 - ws2 = "127.0.0.1:2" - ws.handle_stimulus( - ComputeTaskEvent.dummy( - "c", - who_has={"a": [ws2]}, - nbytes={"a": 200}, - stimulus_id="s1", - ) - ) - assert ws.tasks["a"].state == "flight" - - -def test_throttle_incoming_transfers_on_count_limit(ws): - ws.transfer_incoming_count_limit = 1 - ws.transfer_incoming_bytes_limit = 100_000 - ws.transfer_incoming_bytes_throttle_threshold = 1 - ws2 = "127.0.0.1:2" - ws3 = "127.0.0.1:3" - who_has = {"a": [ws2], "b": [ws3]} - ws.handle_stimulus( - ComputeTaskEvent.dummy( - "c", - who_has=who_has, - nbytes={"a": 100, "b": 100}, - stimulus_id="s1", - ) - ) - tasks_by_state = defaultdict(list) - for ts in ws.tasks.values(): - tasks_by_state[ts.state].append(ts) - assert len(tasks_by_state["flight"]) == 1 - assert len(tasks_by_state["fetch"]) == 1 - assert ws.transfer_incoming_bytes == 100 - - in_flight_task = tasks_by_state["flight"][0] - ws.handle_stimulus( - GatherDepSuccessEvent( - worker=who_has[in_flight_task.key][0], - data={in_flight_task.key: 123}, - total_nbytes=100, - stimulus_id="s2", - ) - ) - assert tasks_by_state["flight"][0].state == "memory" - assert tasks_by_state["fetch"][0].state == "flight" - assert ws.transfer_incoming_bytes == 100 - - -def test_throttling_incoming_transfer_on_transfer_bytes_same_worker(ws): - ws.transfer_incoming_count_limit = 100 - ws.transfer_incoming_bytes_limit = 250 - ws.transfer_incoming_bytes_throttle_threshold = 1 - ws2 = "127.0.0.1:2" - ws.handle_stimulus( - ComputeTaskEvent.dummy( - "d", - who_has={"a": [ws2], "b": [ws2], "c": [ws2]}, - nbytes={"a": 100, "b": 100, "c": 100}, - stimulus_id="s1", - ) - ) - tasks_by_state = defaultdict(list) - for ts in ws.tasks.values(): - tasks_by_state[ts.state].append(ts) - assert ws.transfer_incoming_bytes == 200 - assert len(tasks_by_state["flight"]) == 2 - assert len(tasks_by_state["fetch"]) == 1 - - ws.handle_stimulus( - GatherDepSuccessEvent( - worker=ws2, - data={ts.key: 123 for ts in tasks_by_state["flight"]}, - total_nbytes=200, - stimulus_id="s2", - ) - ) - assert all(ts.state == "memory" for ts in tasks_by_state["flight"]) - assert all(ts.state == "flight" for ts in tasks_by_state["fetch"]) - - -def test_throttling_incoming_transfer_on_transfer_bytes_different_workers(ws): - ws.transfer_incoming_count_limit = 100 - ws.transfer_incoming_bytes_limit = 150 - ws.transfer_incoming_bytes_throttle_threshold = 1 - ws2 = "127.0.0.1:2" - ws3 = "127.0.0.1:3" - who_has = {"a": 
[ws2], "b": [ws3]} - ws.handle_stimulus( - ComputeTaskEvent.dummy( - "c", - who_has=who_has, - nbytes={"a": 100, "b": 100}, - stimulus_id="s1", - ) - ) - tasks_by_state = defaultdict(list) - for ts in ws.tasks.values(): - tasks_by_state[ts.state].append(ts) - assert ws.transfer_incoming_bytes == 100 - assert len(tasks_by_state["flight"]) == 1 - assert len(tasks_by_state["fetch"]) == 1 - - in_flight_task = tasks_by_state["flight"][0] - ws.handle_stimulus( - GatherDepSuccessEvent( - worker=who_has[in_flight_task.key][0], - data={in_flight_task.key: 123}, - total_nbytes=100, - stimulus_id="s2", - ) - ) - assert tasks_by_state["flight"][0].state == "memory" - assert tasks_by_state["fetch"][0].state == "flight" - - -def test_do_not_throttle_connections_while_below_threshold(ws): - ws.transfer_incoming_count_limit = 1 - ws.transfer_incoming_bytes_limit = 200 - ws.transfer_incoming_bytes_throttle_threshold = 50 - ws2 = "127.0.0.1:2" - ws3 = "127.0.0.1:3" - ws4 = "127.0.0.1:4" - ws.handle_stimulus( - ComputeTaskEvent.dummy( - "b", - who_has={"a": [ws2]}, - nbytes={"a": 1}, - stimulus_id="s1", - ) - ) - assert ws.tasks["a"].state == "flight" - - ws.handle_stimulus( - ComputeTaskEvent.dummy( - "d", - who_has={"c": [ws3]}, - nbytes={"c": 1}, - stimulus_id="s2", - ) - ) - assert ws.tasks["c"].state == "flight" - - ws.handle_stimulus( - ComputeTaskEvent.dummy( - "f", - who_has={"e": [ws4]}, - nbytes={"e": 100}, - stimulus_id="s3", - ) - ) - assert ws.tasks["e"].state == "flight" - assert ws.transfer_incoming_bytes == 102 - - -def test_throttle_on_transfer_bytes_regardless_of_threshold(ws): - ws.transfer_incoming_count_limit = 1 - ws.transfer_incoming_bytes_limit = 100 - ws.transfer_incoming_bytes_throttle_threshold = 50 - ws2 = "127.0.0.1:2" - ws3 = "127.0.0.1:3" - ws.handle_stimulus( - ComputeTaskEvent.dummy( - "b", - who_has={"a": [ws2]}, - nbytes={"a": 1}, - stimulus_id="s1", - ) - ) - assert ws.tasks["a"].state == "flight" - - ws.handle_stimulus( - ComputeTaskEvent.dummy( - "d", - who_has={"c": [ws3]}, - nbytes={"c": 100}, - stimulus_id="s2", - ) - ) - assert ws.tasks["c"].state == "fetch" - assert ws.transfer_incoming_bytes == 1 diff --git a/distributed/worker.py b/distributed/worker.py index 9764de061bf..3d4b98da019 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -745,18 +745,6 @@ def __init__( memory_spill_fraction=memory_spill_fraction, memory_pause_fraction=memory_pause_fraction, ) - - transfer_incoming_bytes_limit = None - transfer_incoming_bytes_fraction = dask.config.get( - "distributed.worker.memory.transfer" - ) - if ( - self.memory_manager.memory_limit is not None - and transfer_incoming_bytes_fraction is not False - ): - transfer_incoming_bytes_limit = int( - self.memory_manager.memory_limit * transfer_incoming_bytes_fraction - ) state = WorkerState( nthreads=nthreads, data=self.memory_manager.data, @@ -766,7 +754,6 @@ def __init__( transfer_incoming_count_limit=transfer_incoming_count_limit, validate=validate, transition_counter_max=transition_counter_max, - transfer_incoming_bytes_limit=transfer_incoming_bytes_limit, ) BaseWorker.__init__(self, state) diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 8080c03b626..fec249c5abb 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -1232,9 +1232,6 @@ class WorkerState: #: In production, it should always be set to False. 
transition_counter_max: int | Literal[False] - #: Limit of bytes for incoming data transfers; this is used for throttling. - transfer_incoming_bytes_limit: int | None - #: Statically-seeded random state, used to guarantee determinism whenever a #: pseudo-random choice is required rng: random.Random @@ -1253,7 +1250,6 @@ def __init__( transfer_incoming_count_limit: int = 9999, validate: bool = True, transition_counter_max: int | Literal[False] = False, - transfer_incoming_bytes_limit: int | None = None, ): self.nthreads = nthreads @@ -1297,7 +1293,6 @@ def __init__( self.stimulus_log = deque(maxlen=10_000) self.transition_counter = 0 self.transition_counter_max = transition_counter_max - self.transfer_incoming_bytes_limit = transfer_incoming_bytes_limit self.actors = {} self.rng = random.Random(0) @@ -1473,36 +1468,17 @@ def _purge_state(self, ts: TaskState) -> None: self.long_running.discard(ts) self.in_flight_tasks.discard(ts) - def _should_throttle_incoming_transfers(self) -> bool: - """Decides whether the WorkerState should throttle data transfers from other workers. - - Returns - ------- - * True if the number of incoming data transfers reached its limit - and the size of incoming data transfers reached the minimum threshold for throttling - * True if the size of incoming data transfers reached its limit - * False otherwise - """ - reached_count_limit = ( - self.transfer_incoming_count >= self.transfer_incoming_count_limit - ) - reached_throttle_threshold = ( - self.transfer_incoming_bytes - >= self.transfer_incoming_bytes_throttle_threshold - ) - reached_bytes_limit = ( - self.transfer_incoming_bytes_limit is not None - and self.transfer_incoming_bytes >= self.transfer_incoming_bytes_limit - ) - return reached_count_limit and reached_throttle_threshold or reached_bytes_limit - def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: """Transition tasks from fetch to flight, until there are no more tasks in fetch state or a threshold has been reached. """ if not self.running or not self.data_needed: return {}, [] - if self._should_throttle_incoming_transfers(): + if ( + self.transfer_incoming_count >= self.transfer_incoming_count_limit + and self.transfer_incoming_bytes + >= self.transfer_incoming_bytes_throttle_threshold + ): return {}, [] recommendations: Recs = {} @@ -1513,12 +1489,7 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: to_gather_tasks, total_nbytes = self._select_keys_for_gather( available_tasks ) - # We always load at least one task - assert to_gather_tasks or self.transfer_incoming_bytes - # ...but that task might be selected in the previous iteration of the loop - if not to_gather_tasks: - break - + assert to_gather_tasks to_gather_keys = {ts.key for ts in to_gather_tasks} logger.debug( @@ -1560,7 +1531,11 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: self.in_flight_workers[worker] = to_gather_keys self.transfer_incoming_count_total += 1 self.transfer_incoming_bytes += total_nbytes - if self._should_throttle_incoming_transfers(): + if ( + self.transfer_incoming_count >= self.transfer_incoming_count_limit + and self.transfer_incoming_bytes + >= self.transfer_incoming_bytes_throttle_threshold + ): break return recommendations, instructions @@ -1635,28 +1610,18 @@ def _select_keys_for_gather( """Helper of _ensure_communicating. 
Fetch all tasks that are replicated on the target worker within a single - message, up to transfer_message_target_bytes or until we reach the limit - for the size of incoming data transfers. + message, up to transfer_message_target_bytes. """ to_gather: list[TaskState] = [] total_nbytes = 0 - if self.transfer_incoming_bytes_limit is not None: - bytes_left_to_fetch = min( - self.transfer_incoming_bytes_limit - self.transfer_incoming_bytes, - self.transfer_message_target_bytes, - ) - else: - bytes_left_to_fetch = self.transfer_message_target_bytes - while available: ts = available.peek() + # The top-priority task is fetched regardless of its size if ( - # When there is no other traffic, the top-priority task is fetched - # regardless of its size to ensure progress - self.transfer_incoming_bytes - or to_gather - ) and total_nbytes + ts.get_nbytes() > bytes_left_to_fetch: + to_gather + and total_nbytes + ts.get_nbytes() > self.transfer_message_target_bytes + ): break for worker in ts.who_has: # This also effectively pops from available