Skip to content

Commit

Permalink
Fix transfer limiting in _select_keys_for_gather (dask#7071)
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait authored and gjoseph92 committed Oct 31, 2022
1 parent 7d0c0e8 commit fce949a
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 36 deletions.
52 changes: 41 additions & 11 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,30 +1010,40 @@ async def test_deprecated_worker_attributes(s, a, b):
assert a.data_needed == set()


@pytest.mark.parametrize("n_remote_workers", [1, 2])
@pytest.mark.parametrize(
"nbytes,n_in_flight",
"nbytes,n_in_flight_per_worker",
[
(int(10e6), 3),
(int(20e6), 2),
(int(30e6), 1),
(int(60e6), 1),
],
)
def test_aggregate_gather_deps(ws, nbytes, n_in_flight):
def test_aggregate_gather_deps(ws, nbytes, n_in_flight_per_worker, n_remote_workers):
ws.transfer_message_bytes_limit = int(50e6)
ws2 = "127.0.0.1:2"
wss = [f"127.0.0.1:{2 + i}" for i in range(n_remote_workers)]
who_has = {f"x{i}": [wss[i // 3]] for i in range(3 * n_remote_workers)}
instructions = ws.handle_stimulus(
AcquireReplicasEvent(
who_has={"x1": [ws2], "x2": [ws2], "x3": [ws2]},
nbytes={"x1": nbytes, "x2": nbytes, "x3": nbytes},
who_has=who_has,
nbytes={task: nbytes for task in who_has.keys()},
stimulus_id="s1",
)
)
assert instructions == [GatherDep.match(worker=ws2, stimulus_id="s1")]
assert len(instructions[0].to_gather) == n_in_flight
assert len(ws.in_flight_tasks) == n_in_flight
assert ws.transfer_incoming_bytes == nbytes * n_in_flight
assert ws.transfer_incoming_count == 1
assert ws.transfer_incoming_count_total == 1
assert instructions == [
GatherDep.match(worker=remote, stimulus_id="s1") for remote in wss
]
assert all(
len(instruction.to_gather) == n_in_flight_per_worker
for instruction in instructions
)
assert len(ws.in_flight_tasks) == n_in_flight_per_worker * n_remote_workers
assert (
ws.transfer_incoming_bytes == nbytes * n_in_flight_per_worker * n_remote_workers
)
assert ws.transfer_incoming_count == n_remote_workers
assert ws.transfer_incoming_count_total == n_remote_workers


def test_gather_priority(ws):
Expand Down Expand Up @@ -1358,6 +1368,7 @@ def test_throttling_does_not_affect_first_transfer(ws):
ws.transfer_incoming_count_limit = 100
ws.transfer_incoming_bytes_limit = 100
ws.transfer_incoming_bytes_throttle_threshold = 1
ws.transfer_message_bytes_limit = 100
ws2 = "127.0.0.1:2"
ws.handle_stimulus(
ComputeTaskEvent.dummy(
Expand All @@ -1370,6 +1381,25 @@ def test_throttling_does_not_affect_first_transfer(ws):
assert ws.tasks["a"].state == "flight"


def test_message_target_does_not_affect_first_transfer_on_different_worker(ws):
ws.transfer_incoming_count_limit = 100
ws.transfer_incoming_bytes_limit = 600
ws.transfer_message_bytes_limit = 100
ws.transfer_incoming_bytes_throttle_threshold = 1
ws2 = "127.0.0.1:2"
ws3 = "127.0.0.1:3"
ws.handle_stimulus(
ComputeTaskEvent.dummy(
"c",
who_has={"a": [ws2], "b": [ws3]},
nbytes={"a": 200, "b": 200},
stimulus_id="s1",
)
)
assert ws.tasks["a"].state == "flight"
assert ws.tasks["b"].state == "flight"


def test_throttle_incoming_transfers_on_count_limit(ws):
ws.transfer_incoming_count_limit = 1
ws.transfer_incoming_bytes_limit = 100_000
Expand Down
3 changes: 2 additions & 1 deletion distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import errno
import functools
import logging
import math
import os
import pathlib
import random
Expand Down Expand Up @@ -748,7 +749,7 @@ def __init__(
memory_pause_fraction=memory_pause_fraction,
)

transfer_incoming_bytes_limit = None
transfer_incoming_bytes_limit = math.inf
transfer_incoming_bytes_fraction = dask.config.get(
"distributed.worker.memory.transfer"
)
Expand Down
72 changes: 48 additions & 24 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,7 +1234,7 @@ class WorkerState:
transition_counter_max: int | Literal[False]

#: Limit of bytes for incoming data transfers; this is used for throttling.
transfer_incoming_bytes_limit: int | None
transfer_incoming_bytes_limit: float

#: Statically-seeded random state, used to guarantee determinism whenever a
#: pseudo-random choice is required
Expand All @@ -1254,7 +1254,7 @@ def __init__(
transfer_incoming_count_limit: int = 9999,
validate: bool = True,
transition_counter_max: int | Literal[False] = False,
transfer_incoming_bytes_limit: int | None = None,
transfer_incoming_bytes_limit: float = math.inf,
transfer_message_bytes_limit: float = math.inf,
):
self.nthreads = nthreads
Expand Down Expand Up @@ -1493,8 +1493,7 @@ def _should_throttle_incoming_transfers(self) -> bool:
>= self.transfer_incoming_bytes_throttle_threshold
)
reached_bytes_limit = (
self.transfer_incoming_bytes_limit is not None
and self.transfer_incoming_bytes >= self.transfer_incoming_bytes_limit
self.transfer_incoming_bytes >= self.transfer_incoming_bytes_limit
)
return reached_count_limit and reached_throttle_threshold or reached_bytes_limit

Expand All @@ -1512,7 +1511,7 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs:

for worker, available_tasks in self._select_workers_for_gather():
assert worker != self.address
to_gather_tasks, total_nbytes = self._select_keys_for_gather(
to_gather_tasks, message_nbytes = self._select_keys_for_gather(
available_tasks
)
# We always load at least one task
Expand Down Expand Up @@ -1554,14 +1553,14 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs:
GatherDep(
worker=worker,
to_gather=to_gather_keys,
total_nbytes=total_nbytes,
total_nbytes=message_nbytes,
stimulus_id=stimulus_id,
)
)

self.in_flight_workers[worker] = to_gather_keys
self.transfer_incoming_count_total += 1
self.transfer_incoming_bytes += total_nbytes
self.transfer_incoming_bytes += message_nbytes
if self._should_throttle_incoming_transfers():
break

Expand Down Expand Up @@ -1641,32 +1640,57 @@ def _select_keys_for_gather(
for the size of incoming data transfers.
"""
to_gather: list[TaskState] = []
total_nbytes = 0

if self.transfer_incoming_bytes_limit is not None:
bytes_left_to_fetch = min(
self.transfer_incoming_bytes_limit - self.transfer_incoming_bytes,
self.transfer_message_bytes_limit,
)
else:
bytes_left_to_fetch = self.transfer_message_bytes_limit
message_nbytes = 0

while available:
ts = available.peek()
if (
# When there is no other traffic, the top-priority task is fetched
# regardless of its size to ensure progress
self.transfer_incoming_bytes
or to_gather
) and total_nbytes + ts.get_nbytes() > bytes_left_to_fetch:
if self._task_exceeds_transfer_limits(ts, message_nbytes):
break
for worker in ts.who_has:
# This also effectively pops from available
self.data_needed[worker].remove(ts)
to_gather.append(ts)
total_nbytes += ts.get_nbytes()
message_nbytes += ts.get_nbytes()

return to_gather, message_nbytes

def _task_exceeds_transfer_limits(self, ts: TaskState, message_nbytes: int) -> bool:
"""Would asking to gather this task exceed transfer limits?
Parameters
----------
ts
Candidate task for gathering
message_nbytes
Total number of bytes already scheduled for gathering in this message
Returns
-------
exceeds_limit
True if gathering the task would exceed limits, False otherwise
(in which case the task can be gathered).
"""
if self.transfer_incoming_bytes == 0 and message_nbytes == 0:
# When there is no other traffic, the top-priority task is fetched
# regardless of its size to ensure progress
return False

incoming_bytes_allowance = (
self.transfer_incoming_bytes_limit - self.transfer_incoming_bytes
)

# If message_nbytes == 0, i.e., this is the first task to gather in this
# message, ignore `self.transfer_message_bytes_limit` for the top-priority
# task to ensure progress. Otherwise:
if message_nbytes != 0:
incoming_bytes_allowance = (
min(
incoming_bytes_allowance,
self.transfer_message_bytes_limit,
)
- message_nbytes
)

return to_gather, total_nbytes
return ts.get_nbytes() > incoming_bytes_allowance

def _ensure_computing(self) -> RecsInstrs:
if not self.running:
Expand Down

0 comments on commit fce949a

Please sign in to comment.