From 6d1c68a25b753e872fe8806da61918c2b49b0c9a Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Wed, 27 Apr 2022 14:53:04 +0100
Subject: [PATCH] Allow pausing and choke event loop while spilling (#6189)

---
 distributed/tests/test_worker_memory.py | 115 +++++++++++++++++++++++-
 distributed/worker_memory.py            |  22 ++++-
 2 files changed, 134 insertions(+), 3 deletions(-)

diff --git a/distributed/tests/test_worker_memory.py b/distributed/tests/test_worker_memory.py
index 0c9e88c12db..7698a2bddc2 100644
--- a/distributed/tests/test_worker_memory.py
+++ b/distributed/tests/test_worker_memory.py
@@ -2,7 +2,8 @@
 
 import asyncio
 import logging
-from collections import UserDict
+import threading
+from collections import Counter, UserDict
 from time import sleep
 
 import pytest
@@ -12,6 +13,7 @@
 import distributed.system
 from distributed import Client, Event, Nanny, Worker, wait
 from distributed.core import Status
+from distributed.metrics import monotonic
 from distributed.spill import has_zict_210
 from distributed.utils_test import captured_logger, gen_cluster, inc
 from distributed.worker_memory import parse_memory_limit
@@ -636,6 +638,117 @@ def leak():
     assert "memory" in out.lower()
 
 
+@gen_cluster(
+    nthreads=[("", 1)],
+    client=True,
+    worker_kwargs={"memory_limit": "10 GiB"},
+    config={
+        "distributed.worker.memory.target": False,
+        "distributed.worker.memory.spill": 0.7,
+        "distributed.worker.memory.pause": 0.9,
+        "distributed.worker.memory.monitor-interval": "10ms",
+    },
+)
+async def test_pause_while_spilling(c, s, a):
+    N_PAUSE = 3
+    N_TOTAL = 5
+
+    def get_process_memory():
+        if len(a.data) < N_PAUSE:
+            # Don't trigger spilling until the first N_PAUSE tasks have completed
+            return 0
+        elif a.data.fast and not a.data.slow:
+            # Trigger spilling
+            return 8 * 2**30
+        else:
+            # Trigger pause, but only after we've started spilling
+            return 10 * 2**30
+
+    a.monitor.get_process_memory = get_process_memory
+
+    class SlowSpill:
+        def __init__(self, _):
+            # Can't pickle a Semaphore, so instead of a class-level default we
+            # create it here. Don't worry about race conditions; the worker is
+            # single-threaded.
+            if not hasattr(type(self), "sem"):
+                type(self).sem = threading.Semaphore(N_PAUSE)
+            # Block if there are already N_PAUSE instances in a.data.fast
+            self.sem.acquire()
+
+        def __reduce__(self):
+            paused = distributed.get_worker().status == Status.paused
+            if not paused:
+                sleep(0.1)
+            self.sem.release()
+            return bool, (paused,)
+
+    futs = c.map(SlowSpill, range(N_TOTAL))
+    while len(a.data.slow) < N_PAUSE + 1:
+        await asyncio.sleep(0.01)
+
+    assert a.status == Status.paused
+    # The worker should have become paused after the first `SlowSpill` was
+    # evicted, because the spill to disk took longer than the memory monitor
+    # interval.
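+    # While paused, __reduce__ skips its 0.1s sleep and releases the semaphore
+    # immediately, so the remaining evictions complete quickly; each pickled
+    # payload records whether the worker was paused at the moment it was spilled.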
+    assert len(a.data.fast) == 0
+    assert len(a.data.slow) == N_PAUSE + 1
+    n_spilled_while_paused = sum(paused is True for paused in a.data.slow.values())
+    assert N_PAUSE <= n_spilled_while_paused <= N_PAUSE + 1
+
+
+@pytest.mark.slow
+@gen_cluster(
+    nthreads=[("", 1)],
+    client=True,
+    worker_kwargs={"memory_limit": "10 GiB"},
+    config={
+        "distributed.worker.memory.target": False,
+        "distributed.worker.memory.spill": 0.6,
+        "distributed.worker.memory.pause": False,
+        "distributed.worker.memory.monitor-interval": "10ms",
+    },
+)
+async def test_release_evloop_while_spilling(c, s, a):
+    N = 100
+
+    def get_process_memory():
+        if len(a.data) < N:
+            # Don't trigger spilling until all tasks have completed
+            return 0
+        return 10 * 2**30
+
+    a.monitor.get_process_memory = get_process_memory
+
+    class SlowSpill:
+        def __reduce__(self):
+            sleep(0.01)
+            return SlowSpill, ()
+
+    futs = [c.submit(SlowSpill, pure=False) for _ in range(N)]
+    while len(a.data) < N:
+        await asyncio.sleep(0)
+
+    ts = [monotonic()]
+    while a.data.fast:
+        await asyncio.sleep(0)
+        ts.append(monotonic())
+
+    # 100 tasks taking 0.01s to pickle each = 1s of pickling, doubled to ~2s
+    # because everything is pickled twice
+    # (https://github.com/dask/distributed/issues/1371).
+    # We should regain control of the event loop every 0.5s.
+    counts = Counter(round(t1 - t0, 1) for t0, t1 in zip(ts, ts[1:]))
+    # Depending on the implementation of WorkerMemoryManager._maybe_spill:
+    # - if it calls sleep(0) every 0.5s:              {0.0: 315, 0.5: 4}
+    # - if it calls sleep(0) after spilling each key: {0.0: 233}
+    # - if it never yields:                           {0.0: 359, 2.0: 1}
+    # Make sure we remain in the first use case.
+    assert 1 < sum(v for k, v in counts.items() if 0.5 <= k <= 1.9), dict(counts)
+    assert not any(v for k, v in counts.items() if k >= 2.0), dict(counts)
+
+
 @pytest.mark.parametrize(
     "cls,name,value",
     [
diff --git a/distributed/worker_memory.py b/distributed/worker_memory.py
index d4ef636e9f9..904e4899dfd 100644
--- a/distributed/worker_memory.py
+++ b/distributed/worker_memory.py
@@ -39,6 +39,7 @@
 
 from distributed import system
 from distributed.core import Status
+from distributed.metrics import monotonic
 from distributed.spill import ManualEvictProto, SpillBuffer
 from distributed.utils import log_errors
 from distributed.utils_perf import ThrottledGC
@@ -234,6 +235,8 @@ async def _maybe_spill(self, worker: Worker, memory: int) -> None:
         )
         count = 0
         need = memory - target
+        last_checked_for_pause = last_yielded = monotonic()
+
         while memory > target:
             if not data.fast:
                 logger.warning(
@@ -255,7 +258,6 @@ async def _maybe_spill(self, worker: Worker, memory: int) -> None:
             total_spilled += weight
             count += 1
 
-            await asyncio.sleep(0)
             memory = worker.monitor.get_process_memory()
 
             if total_spilled > need and memory > target:
@@ -265,7 +267,23 @@ async def _maybe_spill(self, worker: Worker, memory: int) -> None:
                 self._throttled_gc.collect()
                 memory = worker.monitor.get_process_memory()
 
-            self._maybe_pause_or_unpause(worker, memory)
+            now = monotonic()
+
+            # Spilling may potentially take multiple seconds; we may pass the
+            # pause threshold in the meantime.
+            if now - last_checked_for_pause > self.memory_monitor_interval:
+                self._maybe_pause_or_unpause(worker, memory)
+                last_checked_for_pause = now
+
+            # Increase spilling aggressiveness when the fast buffer is filled
+            # with a lot of small values. This artificially chokes the rest of
+            # the event loop - namely, the reception of new data from other
+            # workers.
+            # While this is somewhat of an ugly hack, DO NOT tweak this without
+            # a thorough cycle of stress testing. See
+            # https://github.com/dask/distributed/issues/6110.
+            if now - last_yielded > 0.5:
+                await asyncio.sleep(0)
+                last_yielded = monotonic()
+
         if count:
             logger.debug(
                 "Moved %d tasks worth %s to disk",
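
The crux of the patch is the pattern at the end of _maybe_spill: evict in a
tight synchronous loop, re-check the pause threshold at most once per memory
monitor interval, and yield the event loop at most every 0.5s instead of after
every key. Below is a minimal standalone sketch of that pattern; spill_loop,
evict_one, out_of_memory, and maybe_pause are illustrative stand-ins, not part
of the distributed API:

    import asyncio
    from time import monotonic, sleep

    async def spill_loop(evict_one, out_of_memory, maybe_pause, monitor_interval=0.01):
        # Evict synchronously in a tight loop. The pause threshold is re-checked
        # at most once per monitor interval; the event loop is released at most
        # every 0.5s, so spilling many small keys chokes, but does not starve,
        # the rest of the worker (e.g. the reception of data from other workers).
        last_checked_for_pause = last_yielded = monotonic()
        while out_of_memory():
            evict_one()  # blocking: pickle + write to disk, possibly slow

            now = monotonic()
            if now - last_checked_for_pause > monitor_interval:
                maybe_pause()  # may flip the worker to paused mid-spill
                last_checked_for_pause = now
            if now - last_yielded > 0.5:
                await asyncio.sleep(0)  # briefly hand control back to the event loop
                last_yielded = monotonic()

    async def main():
        # Toy usage mirroring test_release_evloop_while_spilling: 100 keys
        # taking ~10ms each to evict.
        data = [bytes(2**20) for _ in range(100)]
        await spill_loop(
            evict_one=lambda: (data.pop(), sleep(0.01)),
            out_of_memory=lambda: bool(data),
            maybe_pause=lambda: None,
        )

    asyncio.run(main())

Yielding on a wall-clock budget rather than after every key is what keeps the
interval Counter in test_release_evloop_while_spilling in the asserted shape
({0.0: many, 0.5: a few}) rather than either of the rejected alternatives.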