Skip to content

Commit

Permalink
Revert "Revert "Limit incoming data transfers by amount of data" (#6994
Browse files Browse the repository at this point in the history
…)" (#7007)
  • Loading branch information
fjetter authored Sep 6, 2022
1 parent b133009 commit 94d0c1d
Show file tree
Hide file tree
Showing 5 changed files with 268 additions and 19 deletions.
17 changes: 14 additions & 3 deletions distributed/distributed-schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -508,13 +508,21 @@ properties:
donate data and only nodes below 45% will receive them. This helps
avoid data from bouncing around the cluster repeatedly.
transfer:
oneOf:
- {type: number, minimum: 0, maximum: 1}
- {enum: [false]}
description: >-
When the total size of incoming data transfers gets above this fraction
of the worker's memory limit, we start throttling incoming data transfers
target:
oneOf:
- {type: number, minimum: 0, maximum: 1}
- {enum: [false]}
description: >-
When the process memory (as observed by the operating system) gets
above this amount we start spilling the dask keys holding the largest
above this amount, we start spilling the dask keys holding the oldest
chunks of data to disk
spill:
Expand All @@ -523,15 +531,18 @@ properties:
- {enum: [false]}
description: >-
When the process memory (as observed by the operating system) gets
above this amount we spill all data to disk.
above this amount, we spill data to disk, starting from the dask keys
holding the oldest chunks of data, until the process memory falls below
the target threshold.
pause:
oneOf:
- {type: number, minimum: 0, maximum: 1}
- {enum: [false]}
description: >-
When the process memory (as observed by the operating system) gets
above this amount we no longer start new tasks on this worker.
above this amount, we no longer start new tasks or fetch new
data on the worker.
terminate:
oneOf:
Expand Down
3 changes: 3 additions & 0 deletions distributed/distributed.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ distributed:

# Fractions of worker process memory at which we take action to avoid memory
# blowup. Set any of the values to False to turn off the behavior entirely.
# All fractions are relative to each worker's memory_limit.
transfer: 0.10 # fraction of memory_limit that incoming data transfers may occupy
# before we start throttling further incoming data transfers
target: 0.60 # fraction of managed memory where we start spilling to disk
spill: 0.70 # fraction of process memory where we start spilling to disk
pause: 0.80 # fraction of process memory at which we pause worker threads
Expand Down
187 changes: 187 additions & 0 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import gc
import pickle
from collections import defaultdict
from collections.abc import Iterator

import pytest
Expand Down Expand Up @@ -1351,3 +1352,189 @@ def test_transfer_incoming_metrics(ws):
assert ws.transfer_incoming_bytes == 0
assert ws.transfer_incoming_count == 0
assert ws.transfer_incoming_count_total == 4


def test_throttling_does_not_affect_first_transfer(ws):
    """The first transfer is started even if it exceeds the byte limit.

    The single dependency ``a`` is 200 bytes while the incoming byte limit
    is 100; it must nevertheless reach the ``flight`` state.
    """
    ws.transfer_incoming_count_limit = 100
    ws.transfer_incoming_bytes_limit = 100
    ws.transfer_incoming_bytes_throttle_threshold = 1
    peer = "127.0.0.1:2"

    compute_c = ComputeTaskEvent.dummy(
        "c",
        who_has={"a": [peer]},
        nbytes={"a": 200},
        stimulus_id="s1",
    )
    ws.handle_stimulus(compute_c)

    dependency = ws.tasks["a"]
    assert dependency.state == "flight"


def test_throttle_incoming_transfers_on_count_limit(ws):
    """Two dependencies on two peers are fetched one at a time.

    The count limit is 1 and the byte limit (100_000) is far above the
    200 bytes requested, so presumably only the count limit is in play.
    Once the first transfer succeeds, the second one must take off.
    """
    ws.transfer_incoming_count_limit = 1
    ws.transfer_incoming_bytes_limit = 100_000
    ws.transfer_incoming_bytes_throttle_threshold = 1
    peer_one = "127.0.0.1:2"
    peer_two = "127.0.0.1:3"
    who_has = {"a": [peer_one], "b": [peer_two]}

    compute_c = ComputeTaskEvent.dummy(
        "c",
        who_has=who_has,
        nbytes={"a": 100, "b": 100},
        stimulus_id="s1",
    )
    ws.handle_stimulus(compute_c)

    # Snapshot the tasks by state *before* the transfer completes; the same
    # task objects are re-inspected afterwards.
    flying = [ts for ts in ws.tasks.values() if ts.state == "flight"]
    queued = [ts for ts in ws.tasks.values() if ts.state == "fetch"]
    assert len(flying) == 1
    assert len(queued) == 1
    assert ws.transfer_incoming_bytes == 100

    first = flying[0]
    source = who_has[first.key][0]
    ws.handle_stimulus(
        GatherDepSuccessEvent(
            worker=source,
            data={first.key: 123},
            total_nbytes=100,
            stimulus_id="s2",
        )
    )

    assert first.state == "memory"
    assert queued[0].state == "flight"
    assert ws.transfer_incoming_bytes == 100


def test_throttling_incoming_transfer_on_transfer_bytes_same_worker(ws):
    """Three 100-byte dependencies on one peer with a 250-byte limit.

    Only two of them (200 bytes) may be in flight at first; the third
    stays in ``fetch`` until the first two have arrived.
    """
    ws.transfer_incoming_count_limit = 100
    ws.transfer_incoming_bytes_limit = 250
    ws.transfer_incoming_bytes_throttle_threshold = 1
    peer = "127.0.0.1:2"

    compute_d = ComputeTaskEvent.dummy(
        "d",
        who_has={"a": [peer], "b": [peer], "c": [peer]},
        nbytes={"a": 100, "b": 100, "c": 100},
        stimulus_id="s1",
    )
    ws.handle_stimulus(compute_d)

    # Snapshot taken before the gather completes; reused for the final checks.
    flying = [ts for ts in ws.tasks.values() if ts.state == "flight"]
    queued = [ts for ts in ws.tasks.values() if ts.state == "fetch"]
    assert ws.transfer_incoming_bytes == 200
    assert len(flying) == 2
    assert len(queued) == 1

    received = {ts.key: 123 for ts in flying}
    ws.handle_stimulus(
        GatherDepSuccessEvent(
            worker=peer,
            data=received,
            total_nbytes=200,
            stimulus_id="s2",
        )
    )

    for ts in flying:
        assert ts.state == "memory"
    for ts in queued:
        assert ts.state == "flight"


def test_throttling_incoming_transfer_on_transfer_bytes_different_workers(ws):
    """Two 100-byte dependencies on two peers with a 150-byte limit.

    Only one transfer (100 bytes) may be in flight at first; the other
    stays in ``fetch`` and takes off once the first has arrived.
    """
    ws.transfer_incoming_count_limit = 100
    ws.transfer_incoming_bytes_limit = 150
    ws.transfer_incoming_bytes_throttle_threshold = 1
    peer_one = "127.0.0.1:2"
    peer_two = "127.0.0.1:3"
    who_has = {"a": [peer_one], "b": [peer_two]}

    compute_c = ComputeTaskEvent.dummy(
        "c",
        who_has=who_has,
        nbytes={"a": 100, "b": 100},
        stimulus_id="s1",
    )
    ws.handle_stimulus(compute_c)

    # Snapshot taken before the gather completes; reused for the final checks.
    flying = [ts for ts in ws.tasks.values() if ts.state == "flight"]
    queued = [ts for ts in ws.tasks.values() if ts.state == "fetch"]
    assert ws.transfer_incoming_bytes == 100
    assert len(flying) == 1
    assert len(queued) == 1

    first = flying[0]
    source = who_has[first.key][0]
    ws.handle_stimulus(
        GatherDepSuccessEvent(
            worker=source,
            data={first.key: 123},
            total_nbytes=100,
            stimulus_id="s2",
        )
    )

    assert first.state == "memory"
    assert queued[0].state == "flight"


def test_do_not_throttle_connections_while_below_threshold(ws):
    """Transfers from three peers all start despite a count limit of 1.

    The dependencies are 1, 1 and 100 bytes; each must go straight to
    ``flight``, leaving 102 bytes in flight in total. Presumably the count
    limit is not enforced because incoming bytes are still below the
    throttle threshold of 50 when each transfer is scheduled.
    """
    ws.transfer_incoming_count_limit = 1
    ws.transfer_incoming_bytes_limit = 200
    ws.transfer_incoming_bytes_throttle_threshold = 50

    # (task key, dependency key, peer address, dependency size, stimulus id)
    requests = [
        ("b", "a", "127.0.0.1:2", 1, "s1"),
        ("d", "c", "127.0.0.1:3", 1, "s2"),
        ("f", "e", "127.0.0.1:4", 100, "s3"),
    ]
    for task_key, dep_key, peer, size, stimulus_id in requests:
        ws.handle_stimulus(
            ComputeTaskEvent.dummy(
                task_key,
                who_has={dep_key: [peer]},
                nbytes={dep_key: size},
                stimulus_id=stimulus_id,
            )
        )
        assert ws.tasks[dep_key].state == "flight"

    assert ws.transfer_incoming_bytes == 102


def test_throttle_on_transfer_bytes_regardless_of_threshold(ws):
    """A second transfer is held back although the threshold is not reached.

    With 1 byte already in flight (below the throttle threshold of 50), a
    further 100-byte dependency must stay in ``fetch``; presumably because
    1 + 100 bytes would exceed the byte limit of 100.
    """
    ws.transfer_incoming_count_limit = 1
    ws.transfer_incoming_bytes_limit = 100
    ws.transfer_incoming_bytes_throttle_threshold = 50
    small_peer = "127.0.0.1:2"
    large_peer = "127.0.0.1:3"

    compute_b = ComputeTaskEvent.dummy(
        "b",
        who_has={"a": [small_peer]},
        nbytes={"a": 1},
        stimulus_id="s1",
    )
    ws.handle_stimulus(compute_b)
    assert ws.tasks["a"].state == "flight"

    compute_d = ComputeTaskEvent.dummy(
        "d",
        who_has={"c": [large_peer]},
        nbytes={"c": 100},
        stimulus_id="s2",
    )
    ws.handle_stimulus(compute_d)
    assert ws.tasks["c"].state == "fetch"
    assert ws.transfer_incoming_bytes == 1
13 changes: 13 additions & 0 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,18 @@ def __init__(
memory_spill_fraction=memory_spill_fraction,
memory_pause_fraction=memory_pause_fraction,
)

transfer_incoming_bytes_limit = None
transfer_incoming_bytes_fraction = dask.config.get(
"distributed.worker.memory.transfer"
)
if (
self.memory_manager.memory_limit is not None
and transfer_incoming_bytes_fraction is not False
):
transfer_incoming_bytes_limit = int(
self.memory_manager.memory_limit * transfer_incoming_bytes_fraction
)
state = WorkerState(
nthreads=nthreads,
data=self.memory_manager.data,
Expand All @@ -754,6 +766,7 @@ def __init__(
transfer_incoming_count_limit=transfer_incoming_count_limit,
validate=validate,
transition_counter_max=transition_counter_max,
transfer_incoming_bytes_limit=transfer_incoming_bytes_limit,
)
BaseWorker.__init__(self, state)

Expand Down
Loading

0 comments on commit 94d0c1d

Please sign in to comment.