diff --git a/distributed/tests/test_worker_memory.py b/distributed/tests/test_worker_memory.py
index 030daa914d..915c8fb66c 100644
--- a/distributed/tests/test_worker_memory.py
+++ b/distributed/tests/test_worker_memory.py
@@ -14,7 +14,7 @@
 from distributed.core import Status
 from distributed.spill import has_zict_210
 from distributed.utils_test import captured_logger, gen_cluster, inc
-from distributed.worker_memory import parse_memory_limit
+from distributed.worker_memory import WorkerMemoryManager, parse_memory_limit
 
 requires_zict_210 = pytest.mark.skipif(
     not has_zict_210,
@@ -22,6 +22,59 @@
 )
 
 
+def get_fake_wmm_fast_static(value: float) -> type[WorkerMemoryManager]:
+    """Convenience factory for a fake WorkerMemoryManager.
+
+    The returned class reports a constant process memory of ``value`` whenever
+    there is data in ``data.fast``, and 0 otherwise.
+    """
+
+    class FakeWMMFastStatic(WorkerMemoryManager):
+        def get_process_memory(self):
+            worker = self._worker()
+            if worker and worker.data.fast:
+                return value
+            else:
+                return 0
+
+    return FakeWMMFastStatic
+
+
+def get_fake_wmm_fast_dynamic(value: float) -> type[WorkerMemoryManager]:
+    """Convenience factory for a fake WorkerMemoryManager.
+
+    The returned class reports a process memory of ``value`` per key in ``data.fast``.
+    """
+
+    class FakeWMMFastDyn(WorkerMemoryManager):
+        def get_process_memory(self):
+            worker = self._worker()
+            if worker and worker.data.fast:
+                return value * len(worker.data.fast)
+            else:
+                return 0
+
+    return FakeWMMFastDyn
+
+
+def get_fake_wmm_all_static(value: float) -> type[WorkerMemoryManager]:
+    """Convenience factory for a fake WorkerMemoryManager.
+
+    The returned class reports a constant process memory of ``value`` as long
+    as there is any data in the buffer, and 0 otherwise.
+    """
+
+    class FakeWMMAll(WorkerMemoryManager):
+        def get_process_memory(self):
+            worker = self._worker()
+            if worker and worker.data:
+                return value
+            else:
+                return 0
+
+    return FakeWMMAll
+
+
 def memory_monitor_running(dask_worker: Worker | Nanny) -> bool:
     return "memory_monitor" in dask_worker.periodic_callbacks
 
@@ -109,7 +162,9 @@ async def test_fail_to_pickle_target_1(c, s, a, b):
 @gen_cluster(
     client=True,
     nthreads=[("", 1)],
-    worker_kwargs={"memory_limit": "1 kiB"},
+    worker_kwargs={
+        "memory_limit": "1 kiB",
+    },
     config={
         "distributed.worker.memory.target": 0.5,
         "distributed.worker.memory.spill": False,
@@ -142,7 +197,10 @@ async def test_fail_to_pickle_target_2(c, s, a):
 @gen_cluster(
     client=True,
     nthreads=[("", 1)],
-    worker_kwargs={"memory_limit": "1 kB"},
+    worker_kwargs={
+        "memory_limit": "1 kB",
+        "memory_manager_cls": get_fake_wmm_fast_static(701),
+    },
     config={
         "distributed.worker.memory.target": False,
         "distributed.worker.memory.spill": 0.7,
@@ -151,7 +209,6 @@ async def test_fail_to_pickle_target_2(c, s, a):
 )
 async def test_fail_to_pickle_spill(c, s, a):
     """Test failure to evict a key, triggered by the spill threshold"""
-    a.monitor.get_process_memory = lambda: 701 if a.data.fast else 0
 
     with captured_logger(logging.getLogger("distributed.spill")) as logs:
         bad = c.submit(FailToPickle, key="bad")
@@ -268,7 +325,10 @@ async def test_spill_constrained(c, s, w):
 @gen_cluster(
     nthreads=[("", 1)],
     client=True,
-    worker_kwargs={"memory_limit": "1000 MB"},
+    worker_kwargs={
+        "memory_limit": "1000 MB",
+        "memory_manager_cls": get_fake_wmm_fast_static(800_000_000),
+    },
     config={
         "distributed.worker.memory.target": False,
         "distributed.worker.memory.spill": 0.7,
@@ -282,7 +342,6 @@ async def test_spill_spill_threshold(c, s, a):
     reported by sizeof(), which may be inaccurate.
     """
     assert memory_monitor_running(a)
-    a.monitor.get_process_memory = lambda: 800_000_000 if a.data.fast else 0
     x = c.submit(inc, 0, key="x")
     while not a.data.disk:
         await asyncio.sleep(0.01)
@@ -326,8 +385,11 @@ def __sizeof__(self):
             return managed
 
     with dask.config.set({"distributed.worker.memory.target": target}):
-        async with Worker(s.address, memory_limit="1000 MB") as a:
-            a.monitor.get_process_memory = lambda: 50_000_000 * len(a.data.fast)
+        async with Worker(
+            s.address,
+            memory_limit="1000 MB",
+            memory_manager_cls=get_fake_wmm_fast_dynamic(50_000_000),
+        ) as a:
 
             # Add 500MB (reported) process memory. Spilling must not happen.
             futures = [c.submit(C, pure=False) for _ in range(10)]
diff --git a/distributed/worker.py b/distributed/worker.py
index 13e5adeff0..42228f464b 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -455,6 +455,7 @@ def __init__(
         lifetime_restart: bool | None = None,
         ###################################
         # Parameters to WorkerMemoryManager
+        memory_manager_cls: type[WorkerMemoryManager] = WorkerMemoryManager,
         memory_limit: str | float = "auto",
         # Allow overriding the dict-like that stores the task outputs.
         # This is meant for power users only. See WorkerMemoryManager for details.
@@ -786,7 +787,7 @@ def __init__(
         for ext in extensions:
             ext(self)
 
-        self.memory_manager = WorkerMemoryManager(
+        self.memory_manager = memory_manager_cls(
            self,
             data=data,
             memory_limit=memory_limit,
diff --git a/distributed/worker_memory.py b/distributed/worker_memory.py
index f08dacdb2c..42327511bc 100644
--- a/distributed/worker_memory.py
+++ b/distributed/worker_memory.py
@@ -25,9 +25,9 @@
 import os
 import sys
 import warnings
+import weakref
 from collections.abc import Callable, MutableMapping
 from contextlib import suppress
-from functools import partial
 from typing import TYPE_CHECKING, Any, Container, Literal, cast
 
 import psutil
@@ -61,6 +61,7 @@ class WorkerMemoryManager:
     memory_monitor_interval: float
     _memory_monitoring: bool
     _throttled_gc: ThrottledGC
+    _worker: weakref.ReferenceType[Worker]
 
     def __init__(
         self,
@@ -135,23 +136,32 @@ def __init__(
         )
         assert isinstance(self.memory_monitor_interval, (int, float))
 
+        self._worker = weakref.ref(worker)
+
         if self.memory_limit and (
             self.memory_spill_fraction is not False
             or self.memory_pause_fraction is not False
         ):
             assert self.memory_monitor_interval is not None
             pc = PeriodicCallback(
-                # Don't store worker as self.worker to avoid creating a circular
-                # dependency. We could have alternatively used a weakref.
                 # FIXME annotations: https://github.com/tornadoweb/tornado/issues/3117
-                partial(self.memory_monitor, worker),  # type: ignore
+                self.memory_monitor,  # type: ignore
                 self.memory_monitor_interval * 1000,
             )
             worker.periodic_callbacks["memory_monitor"] = pc
 
         self._throttled_gc = ThrottledGC(logger=logger)
 
-    async def memory_monitor(self, worker: Worker) -> None:
+    def get_process_memory(self) -> int:
+        """Return the observed process memory.
+        Tests may override this method to mock memory readings.
+        """
+        worker = self._worker()
+        if worker:
+            return worker.monitor.get_process_memory()
+        return -1
+
+    async def memory_monitor(self) -> None:
         """Track this process's memory usage and act accordingly.
         If process memory rises above the spill threshold (70%), start dumping data to
         disk until it goes below the target threshold (60%).
@@ -166,16 +176,18 @@ async def memory_monitor(self, worker: Worker) -> None:
             # Don't use psutil directly; instead read from the same API that is used
             # to send info to the Scheduler (e.g. for the benefit of Active Memory
             # Manager) and which can be easily mocked in unit tests.
-            memory = worker.monitor.get_process_memory()
-            self._maybe_pause_or_unpause(worker, memory)
-            await self._maybe_spill(worker, memory)
+            memory = self.get_process_memory()
+            self._maybe_pause_or_unpause(memory)
+            await self._maybe_spill(memory)
         finally:
             self._memory_monitoring = False
 
-    def _maybe_pause_or_unpause(self, worker: Worker, memory: int) -> None:
+    def _maybe_pause_or_unpause(self, memory: int) -> None:
         if self.memory_pause_fraction is False:
             return
-
+        worker = self._worker()
+        if not worker:
+            return
         assert self.memory_limit
         frac = memory / self.memory_limit
         # Pause worker threads if above 80% memory use
@@ -205,7 +217,7 @@ def _maybe_pause_or_unpause(self, worker: Worker, memory: int) -> None:
             )
             worker.status = Status.running
 
-    async def _maybe_spill(self, worker: Worker, memory: int) -> None:
+    async def _maybe_spill(self, memory: int) -> None:
         if self.memory_spill_fraction is False:
             return
 
@@ -257,15 +269,15 @@ async def _maybe_spill(self, worker: Worker, memory: int) -> None:
                 count += 1
                 await asyncio.sleep(0)
 
-            memory = worker.monitor.get_process_memory()
+            memory = self.get_process_memory()
             if total_spilled > need and memory > target:
                 # Issue a GC to ensure that the evicted data is actually
                 # freed from memory and taken into account by the monitor
                 # before trying to evict even more data.
                 self._throttled_gc.collect()
-                memory = worker.monitor.get_process_memory()
+                memory = self.get_process_memory()
 
-        self._maybe_pause_or_unpause(worker, memory)
+        self._maybe_pause_or_unpause(memory)
         if count:
             logger.debug(
                 "Moved %d tasks worth %s to disk",
@@ -287,6 +299,7 @@ class NannyMemoryManager:
     memory_limit: int | None
     memory_terminate_fraction: float | Literal[False]
     memory_monitor_interval: float | None
+    _nanny: weakref.ReferenceType[Nanny]
 
     def __init__(
         self,
@@ -302,32 +315,46 @@ def __init__(
             dask.config.get("distributed.worker.memory.monitor-interval"),
             default=None,
         )
+        self._nanny = weakref.ref(nanny)
         assert isinstance(self.memory_monitor_interval, (int, float))
         if self.memory_limit and self.memory_terminate_fraction is not False:
             pc = PeriodicCallback(
-                partial(self.memory_monitor, nanny),
+                self.memory_monitor,
                 self.memory_monitor_interval * 1000,
             )
             nanny.periodic_callbacks["memory_monitor"] = pc
 
-    def memory_monitor(self, nanny: Nanny) -> None:
+    def get_process_memory(self) -> int:
+        """Return the observed memory of the worker process.
+        Tests may override this method to mock memory readings.
+        """
+        nanny = self._nanny()
+        if nanny:
+            try:
+                proc = nanny._psutil_process
+                return proc.memory_info().rss
+            except (ProcessLookupError, psutil.NoSuchProcess, psutil.AccessDenied):
+                return -1  # pragma: nocover
+        return -1
+
+    def memory_monitor(self) -> None:
         """Track worker's memory. Restart if it goes above terminate fraction."""
-        if nanny.status != Status.running:
-            return  # pragma: nocover
-        if nanny.process is None or nanny.process.process is None:
-            return  # pragma: nocover
-        process = nanny.process.process
-        try:
-            proc = nanny._psutil_process
-            memory = proc.memory_info().rss
-        except (ProcessLookupError, psutil.NoSuchProcess, psutil.AccessDenied):
+        nanny = self._nanny()
+        if (
+            not nanny
+            or nanny.status != Status.running
+            or nanny.process is None
+            or nanny.process.process is None
+            or self.memory_limit is None
+        ):
             return  # pragma: nocover
-
+        memory = self.get_process_memory()
         if memory / self.memory_limit > self.memory_terminate_fraction:
             logger.warning(
                 "Worker exceeded %d%% memory budget. Restarting",
                 100 * self.memory_terminate_fraction,
             )
+            process = nanny.process.process
             process.terminate()
 
 
@@ -403,4 +430,4 @@ def __get__(self, instance: Nanny | Worker | None, owner):
             # This is triggered by Sphinx
             return None  # pragma: nocover
         _warn_deprecated(instance, "memory_monitor")
-        return partial(instance.memory_manager.memory_monitor, instance)
+        return instance.memory_manager.memory_monitor