Allow pausing and choke event loop while spilling (#6189)
crusaderky authored Apr 27, 2022
1 parent 6fb095e commit 6d1c68a
Showing 2 changed files with 134 additions and 3 deletions.
115 changes: 114 additions & 1 deletion distributed/tests/test_worker_memory.py
@@ -2,7 +2,8 @@

import asyncio
import logging
-from collections import UserDict
+import threading
+from collections import Counter, UserDict
from time import sleep

import pytest
@@ -12,6 +13,7 @@
import distributed.system
from distributed import Client, Event, Nanny, Worker, wait
from distributed.core import Status
from distributed.metrics import monotonic
from distributed.spill import has_zict_210
from distributed.utils_test import captured_logger, gen_cluster, inc
from distributed.worker_memory import parse_memory_limit
@@ -636,6 +638,117 @@ def leak():
assert "memory" in out.lower()


@gen_cluster(
    nthreads=[("", 1)],
    client=True,
    worker_kwargs={"memory_limit": "10 GiB"},
    config={
        "distributed.worker.memory.target": False,
        "distributed.worker.memory.spill": 0.7,
        "distributed.worker.memory.pause": 0.9,
        "distributed.worker.memory.monitor-interval": "10ms",
    },
)
async def test_pause_while_spilling(c, s, a):
    N_PAUSE = 3
    N_TOTAL = 5

    def get_process_memory():
        if len(a.data) < N_PAUSE:
            # Don't trigger spilling until after all tasks have completed
            return 0
        elif a.data.fast and not a.data.slow:
            # Trigger spilling
            return 8 * 2**30
        else:
            # Trigger pause, but only after we started spilling
            return 10 * 2**30

    a.monitor.get_process_memory = get_process_memory

    class SlowSpill:
        def __init__(self, _):
            # Can't pickle a Semaphore, so instead of a default value, we create it
            # here. Don't worry about race conditions; the worker is single-threaded.
            if not hasattr(type(self), "sem"):
                type(self).sem = threading.Semaphore(N_PAUSE)
            # Block if there are N_PAUSE tasks in a.data.fast
            self.sem.acquire()

        def __reduce__(self):
            paused = distributed.get_worker().status == Status.paused
            if not paused:
                sleep(0.1)
            self.sem.release()
            return bool, (paused,)

    futs = c.map(SlowSpill, range(N_TOTAL))
    while len(a.data.slow) < N_PAUSE + 1:
        await asyncio.sleep(0.01)

    assert a.status == Status.paused
    # Worker should have become paused after the first `SlowSpill` was evicted, because
    # the spill to disk took longer than the memory monitor interval.
    assert len(a.data.fast) == 0
    assert len(a.data.slow) == N_PAUSE + 1
    n_spilled_while_paused = sum(paused is True for paused in a.data.slow.values())
    assert N_PAUSE <= n_spilled_while_paused <= N_PAUSE + 1
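
For orientation, here is how the fractions in the @gen_cluster config above translate into absolute thresholds and why the fake get_process_memory readings of 8 GiB and 10 GiB behave differently. This arithmetic is an illustration only, not part of the diff:

# Illustration only (not part of the commit): thresholds implied by
# worker_kwargs={"memory_limit": "10 GiB"}, spill=0.7, pause=0.9.
GiB = 2**30
memory_limit = 10 * GiB
spill_threshold = 0.7 * memory_limit   # spilling starts above ~7 GiB
pause_threshold = 0.9 * memory_limit   # worker pauses above ~9 GiB

assert spill_threshold < 8 * GiB < pause_threshold   # 8 GiB reading: spill only
assert pause_threshold < 10 * GiB                    # 10 GiB reading: spill and pause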


@pytest.mark.slow
@gen_cluster(
    nthreads=[("", 1)],
    client=True,
    worker_kwargs={"memory_limit": "10 GiB"},
    config={
        "distributed.worker.memory.target": False,
        "distributed.worker.memory.spill": 0.6,
        "distributed.worker.memory.pause": False,
        "distributed.worker.memory.monitor-interval": "10ms",
    },
)
async def test_release_evloop_while_spilling(c, s, a):
    N = 100

    def get_process_memory():
        if len(a.data) < N:
            # Don't trigger spilling until after all tasks have completed
            return 0
        return 10 * 2**30

    a.monitor.get_process_memory = get_process_memory

    class SlowSpill:
        def __reduce__(self):
            sleep(0.01)
            return SlowSpill, ()

    futs = [c.submit(SlowSpill, pure=False) for _ in range(N)]
    while len(a.data) < N:
        await asyncio.sleep(0)

    ts = [monotonic()]
    while a.data.fast:
        await asyncio.sleep(0)
        ts.append(monotonic())

    # 100 tasks taking 0.01s to pickle each = 2s to spill everything
    # (this is because everything is pickled twice:
    # https://github.com/dask/distributed/issues/1371).
    # We should regain control of the event loop every 0.5s.
    c = Counter(round(t1 - t0, 1) for t0, t1 in zip(ts, ts[1:]))
    # Depending on the implementation of WorkerMemoryMonitor._maybe_spill:
    # if it calls sleep(0) every 0.5s:
    #   {0.0: 315, 0.5: 4}
    # if it calls sleep(0) after spilling each key:
    #   {0.0: 233}
    # if it never yields:
    #   {0.0: 359, 2.0: 1}
    # Make sure we remain in the first use case.
    assert 1 < sum(v for k, v in c.items() if 0.5 <= k <= 1.9), dict(c)
    assert not any(v for k, v in c.items() if k >= 2.0), dict(c)
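
Both tests lean on the fact that spilling serializes values with pickle, so a class whose __reduce__ sleeps makes every eviction predictably slow. A minimal standalone sketch of that mechanism, assuming plain pickle with no dask worker involved:

# Standalone sketch (plain pickle, outside any worker) of how a sleeping
# __reduce__ slows down serialization, which is what makes spilling slow above.
import pickle
from time import monotonic, sleep


class SlowToPickle:
    def __reduce__(self):
        sleep(0.01)              # simulate an expensive serialization step
        return SlowToPickle, ()  # unpickling recreates an empty instance


start = monotonic()
blob = pickle.dumps(SlowToPickle())
print(f"one dumps() call took ~{monotonic() - start:.2f}s")  # roughly 0.01s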


@pytest.mark.parametrize(
"cls,name,value",
[
22 changes: 20 additions & 2 deletions distributed/worker_memory.py
@@ -39,6 +39,7 @@

from distributed import system
from distributed.core import Status
from distributed.metrics import monotonic
from distributed.spill import ManualEvictProto, SpillBuffer
from distributed.utils import log_errors
from distributed.utils_perf import ThrottledGC
@@ -234,6 +235,8 @@ async def _maybe_spill(self, worker: Worker, memory: int) -> None:
        )
        count = 0
        need = memory - target
        last_checked_for_pause = last_yielded = monotonic()

        while memory > target:
            if not data.fast:
                logger.warning(
@@ -255,7 +258,6 @@ async def _maybe_spill(self, worker: Worker, memory: int) -> None:

            total_spilled += weight
            count += 1
-            await asyncio.sleep(0)

            memory = worker.monitor.get_process_memory()
            if total_spilled > need and memory > target:
@@ -265,7 +267,23 @@ async def _maybe_spill(self, worker: Worker, memory: int) -> None:
                self._throttled_gc.collect()
                memory = worker.monitor.get_process_memory()

-        self._maybe_pause_or_unpause(worker, memory)
            now = monotonic()

            # Spilling may potentially take multiple seconds; we may pass the pause
            # threshold in the meantime.
            if now - last_checked_for_pause > self.memory_monitor_interval:
                self._maybe_pause_or_unpause(worker, memory)
                last_checked_for_pause = now

            # Increase spilling aggressiveness when the fast buffer is filled with a lot
            # of small values. This artificially chokes the rest of the event loop -
            # namely, the reception of new data from other workers. While this is
            # somewhat of an ugly hack, DO NOT tweak this without a thorough cycle of
            # stress testing. See: https://github.com/dask/distributed/issues/6110.
            if now - last_yielded > 0.5:
                await asyncio.sleep(0)
                last_yielded = monotonic()

        if count:
            logger.debug(
                "Moved %d tasks worth %s to disk",
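
The core of the change to _maybe_spill above is the throttled yield: instead of awaiting asyncio.sleep(0) after every evicted key (the removed line), the loop now yields to the event loop at most every 0.5 s and re-checks the pause threshold at most once per monitor interval. A minimal sketch of that pattern in plain asyncio; the function name and the 0.01 s per-key cost are illustrative stand-ins, not dask internals:

# Sketch of the throttled-yield pattern: a long synchronous loop that gives
# control back to the event loop at most every `yield_every` seconds.
import asyncio
from time import monotonic, sleep


async def spill_all(n_keys: int = 100, yield_every: float = 0.5) -> None:
    last_yielded = monotonic()
    for _ in range(n_keys):
        sleep(0.01)  # stand-in for a slow, synchronous evict/pickle step
        now = monotonic()
        if now - last_yielded > yield_every:
            await asyncio.sleep(0)       # let other coroutines run briefly
            last_yielded = monotonic()


asyncio.run(spill_all())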
