Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix transfer limiting in _select_keys_for_gather #7071

Merged
merged 12 commits into from
Sep 27, 2022
54 changes: 42 additions & 12 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,30 +1010,40 @@ async def test_deprecated_worker_attributes(s, a, b):
assert a.data_needed == set()


@pytest.mark.parametrize("n_remote_workers", [1, 2])
@pytest.mark.parametrize(
"nbytes,n_in_flight",
"nbytes,n_in_flight_per_worker",
[
# Note: transfer_message_target_bytes = 50e6 bytes
(int(10e6), 3),
(int(20e6), 2),
(int(30e6), 1),
(int(60e6), 1),
],
)
def test_aggregate_gather_deps(ws, nbytes, n_in_flight):
ws2 = "127.0.0.1:2"
def test_aggregate_gather_deps(ws, nbytes, n_in_flight_per_worker, n_remote_workers):
ws.transfer_message_target_bytes = int(50e6)
wss = [f"127.0.0.1:{2 + i}" for i in range(n_remote_workers)]
who_has = {f"x{i}": [wss[i // 3]] for i in range(3 * n_remote_workers)}
instructions = ws.handle_stimulus(
AcquireReplicasEvent(
who_has={"x1": [ws2], "x2": [ws2], "x3": [ws2]},
nbytes={"x1": nbytes, "x2": nbytes, "x3": nbytes},
who_has=who_has,
nbytes={task: nbytes for task in who_has.keys()},
stimulus_id="s1",
)
)
assert instructions == [GatherDep.match(worker=ws2, stimulus_id="s1")]
assert len(instructions[0].to_gather) == n_in_flight
assert len(ws.in_flight_tasks) == n_in_flight
assert ws.transfer_incoming_bytes == nbytes * n_in_flight
assert ws.transfer_incoming_count == 1
assert ws.transfer_incoming_count_total == 1
assert instructions == [
GatherDep.match(worker=remote, stimulus_id="s1") for remote in wss
]
assert all(
len(instruction.to_gather) == n_in_flight_per_worker
for instruction in instructions
)
assert len(ws.in_flight_tasks) == n_in_flight_per_worker * n_remote_workers
assert (
ws.transfer_incoming_bytes == nbytes * n_in_flight_per_worker * n_remote_workers
)
assert ws.transfer_incoming_count == n_remote_workers
assert ws.transfer_incoming_count_total == n_remote_workers


def test_gather_priority(ws):
Expand Down Expand Up @@ -1358,6 +1368,7 @@ def test_throttling_does_not_affect_first_transfer(ws):
ws.transfer_incoming_count_limit = 100
ws.transfer_incoming_bytes_limit = 100
ws.transfer_incoming_bytes_throttle_threshold = 1
ws.transfer_message_target_bytes = 100
ws2 = "127.0.0.1:2"
ws.handle_stimulus(
ComputeTaskEvent.dummy(
Expand All @@ -1370,6 +1381,25 @@ def test_throttling_does_not_affect_first_transfer(ws):
assert ws.tasks["a"].state == "flight"


def test_message_target_does_not_affect_first_transfer_on_different_worker(ws):
ws.transfer_incoming_count_limit = 100
ws.transfer_incoming_bytes_limit = 600
ws.transfer_message_target_bytes = 100
ws.transfer_incoming_bytes_throttle_threshold = 1
ws2 = "127.0.0.1:2"
ws3 = "127.0.0.1:3"
ws.handle_stimulus(
ComputeTaskEvent.dummy(
"c",
who_has={"a": [ws2], "b": [ws3]},
nbytes={"a": 200, "b": 200},
stimulus_id="s1",
)
)
assert ws.tasks["a"].state == "flight"
assert ws.tasks["b"].state == "flight"


def test_throttle_incoming_transfers_on_count_limit(ws):
ws.transfer_incoming_count_limit = 1
ws.transfer_incoming_bytes_limit = 100_000
Expand Down
3 changes: 2 additions & 1 deletion distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import errno
import functools
import logging
import math
import os
import pathlib
import random
Expand Down Expand Up @@ -746,7 +747,7 @@ def __init__(
memory_pause_fraction=memory_pause_fraction,
)

transfer_incoming_bytes_limit = None
transfer_incoming_bytes_limit = math.inf
transfer_incoming_bytes_fraction = dask.config.get(
"distributed.worker.memory.transfer"
)
Expand Down
73 changes: 49 additions & 24 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import asyncio
import heapq
import logging
import math
import operator
import random
import sys
Expand Down Expand Up @@ -1233,7 +1234,7 @@ class WorkerState:
transition_counter_max: int | Literal[False]

#: Limit of bytes for incoming data transfers; this is used for throttling.
transfer_incoming_bytes_limit: int | None
transfer_incoming_bytes_limit: float

#: Statically-seeded random state, used to guarantee determinism whenever a
#: pseudo-random choice is required
Expand All @@ -1253,7 +1254,7 @@ def __init__(
transfer_incoming_count_limit: int = 9999,
validate: bool = True,
transition_counter_max: int | Literal[False] = False,
transfer_incoming_bytes_limit: int | None = None,
transfer_incoming_bytes_limit: float = math.inf,
):
self.nthreads = nthreads

Expand Down Expand Up @@ -1491,8 +1492,7 @@ def _should_throttle_incoming_transfers(self) -> bool:
>= self.transfer_incoming_bytes_throttle_threshold
)
reached_bytes_limit = (
self.transfer_incoming_bytes_limit is not None
and self.transfer_incoming_bytes >= self.transfer_incoming_bytes_limit
self.transfer_incoming_bytes >= self.transfer_incoming_bytes_limit
)
return reached_count_limit and reached_throttle_threshold or reached_bytes_limit

Expand All @@ -1510,7 +1510,7 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs:

for worker, available_tasks in self._select_workers_for_gather():
assert worker != self.address
to_gather_tasks, total_nbytes = self._select_keys_for_gather(
to_gather_tasks, message_nbytes = self._select_keys_for_gather(
available_tasks
)
# We always load at least one task
Expand Down Expand Up @@ -1552,14 +1552,14 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs:
GatherDep(
worker=worker,
to_gather=to_gather_keys,
total_nbytes=total_nbytes,
total_nbytes=message_nbytes,
stimulus_id=stimulus_id,
)
)

self.in_flight_workers[worker] = to_gather_keys
self.transfer_incoming_count_total += 1
self.transfer_incoming_bytes += total_nbytes
self.transfer_incoming_bytes += message_nbytes
if self._should_throttle_incoming_transfers():
break

Expand Down Expand Up @@ -1639,32 +1639,57 @@ def _select_keys_for_gather(
for the size of incoming data transfers.
"""
to_gather: list[TaskState] = []
total_nbytes = 0

if self.transfer_incoming_bytes_limit is not None:
bytes_left_to_fetch = min(
self.transfer_incoming_bytes_limit - self.transfer_incoming_bytes,
self.transfer_message_target_bytes,
)
else:
bytes_left_to_fetch = self.transfer_message_target_bytes
message_nbytes = 0

while available:
ts = available.peek()
if (
# When there is no other traffic, the top-priority task is fetched
# regardless of its size to ensure progress
self.transfer_incoming_bytes
or to_gather
) and total_nbytes + ts.get_nbytes() > bytes_left_to_fetch:
if self._task_exceeds_transfer_limits(ts, message_nbytes):
break
for worker in ts.who_has:
# This also effectively pops from available
self.data_needed[worker].remove(ts)
to_gather.append(ts)
total_nbytes += ts.get_nbytes()
message_nbytes += ts.get_nbytes()

return to_gather, message_nbytes

def _task_exceeds_transfer_limits(self, ts: TaskState, message_nbytes: int) -> bool:
"""Would asking to gather this task exceed transfer limits?

Parameters
----------
ts
Candidate task for gathering
message_nbytes
Total number of bytes already scheduled for gathering in this message
Returns
-------
exceeds_limit
True if gathering the task would exceed limits, False otherwise
(in which case the task can be gathered).
"""
if self.transfer_incoming_bytes == 0 and message_nbytes == 0:
# When there is no other traffic, the top-priority task is fetched
# regardless of its size to ensure progress
return False

incoming_bytes_allowance = (
self.transfer_incoming_bytes_limit - self.transfer_incoming_bytes
)

# If message_nbytes == 0, i.e., this is the first task to gather in this
# message, ignore `self.transfer_message_target_bytes` for the top-priority
# task to ensure progress. Otherwise:
if message_nbytes != 0:
incoming_bytes_allowance = (
min(
incoming_bytes_allowance,
self.transfer_message_target_bytes,
)
- message_nbytes
)

return to_gather, total_nbytes
return ts.get_nbytes() > incoming_bytes_allowance

def _ensure_computing(self) -> RecsInstrs:
if not self.running:
Expand Down