Use dependency injection for proc memory mocks #6004

Closed · wants to merge 3 commits
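In brief (a hedged reading of the diff below): instead of monkeypatching get_process_memory on a live worker's system monitor, tests now subclass WorkerMemoryManager, override its new get_process_memory hook, and inject the subclass through a new memory_manager_cls keyword on Worker. A minimal sketch of the two styles; FakeWMM is an illustrative name and the 800_000_000 figure is taken from the tests below:

# Before: monkeypatch the monitor on a running worker `a`.
a.monitor.get_process_memory = lambda: 800_000_000 if a.data.fast else 0

# After: subclass the manager and inject it at worker construction time.
from distributed.worker_memory import WorkerMemoryManager

class FakeWMM(WorkerMemoryManager):
    def get_process_memory(self) -> int:
        worker = self._worker()  # weak reference back to the owning Worker
        return 800_000_000 if worker and worker.data.fast else 0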
78 changes: 70 additions & 8 deletions distributed/tests/test_worker_memory.py
@@ -14,14 +14,67 @@
from distributed.core import Status
from distributed.spill import has_zict_210
from distributed.utils_test import captured_logger, gen_cluster, inc
from distributed.worker_memory import parse_memory_limit
from distributed.worker_memory import WorkerMemoryManager, parse_memory_limit

requires_zict_210 = pytest.mark.skipif(
not has_zict_210,
reason="requires zict version >= 2.1.0",
)


def get_fake_wmm_fast_static(value: float) -> type[WorkerMemoryManager]:
"""Fake factory for WorkerMemoryManager for convenience

This will set the observed process memory to be constant to ``value`` if
there is data in `data.fast`.
"""

class FakeWMMFastStatic(WorkerMemoryManager):
def get_process_memory(self):
            worker = self._worker()
if worker and worker.data.fast:
return value
else:
return 0

return FakeWMMFastStatic


def get_fake_wmm_fast_dynamic(value: float) -> type[WorkerMemoryManager]:
"""Fake factory for WorkerMemoryManager for convenience

This will set the observed process memory to be ``value`` times the number of elements in `data.fast`.
"""

class FakeWMMFastDyn(WorkerMemoryManager):
def get_process_memory(self):
            worker = self._worker()
if worker and worker.data.fast:
return value * len(worker.data.fast)
else:
return 0

return FakeWMMFastDyn


def get_fake_wmm_all_static(value: float) -> type[WorkerMemoryManager]:
"""Fake factory for WorkerMemoryManager for convenience

This will set the observed process memory to be ``value`` as long as there
is any data in the buffer
"""

class FakeWMMAll(WorkerMemoryManager):
def get_process_memory(self):
            worker = self._worker()
if worker and worker.data:
return value
else:
return 0

return FakeWMMAll


def memory_monitor_running(dask_worker: Worker | Nanny) -> bool:
return "memory_monitor" in dask_worker.periodic_callbacks

@@ -109,7 +162,9 @@ async def test_fail_to_pickle_target_1(c, s, a, b):
@gen_cluster(
client=True,
nthreads=[("", 1)],
worker_kwargs={"memory_limit": "1 kiB"},
worker_kwargs={
"memory_limit": "1 kiB",
},
config={
"distributed.worker.memory.target": 0.5,
"distributed.worker.memory.spill": False,
@@ -142,7 +197,10 @@ async def test_fail_to_pickle_target_2(c, s, a):
@gen_cluster(
client=True,
nthreads=[("", 1)],
worker_kwargs={"memory_limit": "1 kB"},
worker_kwargs={
"memory_limit": "1 kB",
"memory_manager_cls": get_fake_wmm_fast_static(701),
},
config={
"distributed.worker.memory.target": False,
"distributed.worker.memory.spill": 0.7,
@@ -151,7 +209,6 @@ async def test_fail_to_pickle_target_2(c, s, a):
)
async def test_fail_to_pickle_spill(c, s, a):
"""Test failure to evict a key, triggered by the spill threshold"""
a.monitor.get_process_memory = lambda: 701 if a.data.fast else 0

with captured_logger(logging.getLogger("distributed.spill")) as logs:
bad = c.submit(FailToPickle, key="bad")
@@ -268,7 +325,10 @@ async def test_spill_constrained(c, s, w):
@gen_cluster(
nthreads=[("", 1)],
client=True,
worker_kwargs={"memory_limit": "1000 MB"},
worker_kwargs={
"memory_limit": "1000 MB",
"memory_manager_cls": get_fake_wmm_fast_static(800_000_000),
},
config={
"distributed.worker.memory.target": False,
"distributed.worker.memory.spill": 0.7,
@@ -282,7 +342,6 @@ async def test_spill_spill_threshold(c, s, a):
reported by sizeof(), which may be inaccurate.
"""
assert memory_monitor_running(a)
a.monitor.get_process_memory = lambda: 800_000_000 if a.data.fast else 0
x = c.submit(inc, 0, key="x")
while not a.data.disk:
await asyncio.sleep(0.01)
@@ -326,8 +385,11 @@ def __sizeof__(self):
return managed

with dask.config.set({"distributed.worker.memory.target": target}):
async with Worker(s.address, memory_limit="1000 MB") as a:
a.monitor.get_process_memory = lambda: 50_000_000 * len(a.data.fast)
async with Worker(
s.address,
memory_limit="1000 MB",
memory_manager_cls=get_fake_wmm_fast_dynamic(50_000_000),
) as a:

# Add 500MB (reported) process memory. Spilling must not happen.
futures = [c.submit(C, pure=False) for _ in range(10)]
3 changes: 2 additions & 1 deletion distributed/worker.py
@@ -455,6 +455,7 @@ def __init__(
lifetime_restart: bool | None = None,
###################################
# Parameters to WorkerMemoryManager
memory_manager_cls: type[WorkerMemoryManager] = WorkerMemoryManager,
memory_limit: str | float = "auto",
# Allow overriding the dict-like that stores the task outputs.
# This is meant for power users only. See WorkerMemoryManager for details.
@@ -786,7 +787,7 @@ def __init__(
for ext in extensions:
ext(self)

self.memory_manager = WorkerMemoryManager(
self.memory_manager = memory_manager_cls(
self,
data=data,
memory_limit=memory_limit,
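The new keyword also works when constructing a Worker directly, as the dynamic-memory test above does; a minimal sketch, assuming a scheduler listening at s.address:

async with Worker(
    s.address,
    memory_limit="1000 MB",
    memory_manager_cls=get_fake_wmm_fast_dynamic(50_000_000),
) as a:
    # Each key held in a.data.fast now counts as 50 MB of "process" memory.
    ...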
81 changes: 54 additions & 27 deletions distributed/worker_memory.py
@@ -25,9 +25,9 @@
import os
import sys
import warnings
import weakref
from collections.abc import Callable, MutableMapping
from contextlib import suppress
from functools import partial
from typing import TYPE_CHECKING, Any, Container, Literal, cast

import psutil
@@ -61,6 +61,7 @@ class WorkerMemoryManager:
memory_monitor_interval: float
_memory_monitoring: bool
_throttled_gc: ThrottledGC
_worker: weakref.ReferenceType[Worker]

def __init__(
self,
@@ -135,23 +136,32 @@ def __init__(
)
assert isinstance(self.memory_monitor_interval, (int, float))

self._worker = weakref.ref(worker)

if self.memory_limit and (
self.memory_spill_fraction is not False
or self.memory_pause_fraction is not False
):
assert self.memory_monitor_interval is not None
pc = PeriodicCallback(
# Don't store worker as self.worker to avoid creating a circular
# dependency. We could have alternatively used a weakref.
# FIXME annotations: https://github.com/tornadoweb/tornado/issues/3117
partial(self.memory_monitor, worker), # type: ignore
self.memory_monitor, # type: ignore
self.memory_monitor_interval * 1000,
)
worker.periodic_callbacks["memory_monitor"] = pc

self._throttled_gc = ThrottledGC(logger=logger)

async def memory_monitor(self, worker: Worker) -> None:
def get_process_memory(self) -> int:
"""Get a measure for process memory.
This can be a mock target.
"""
worker = self._worker()
if worker:
return worker.monitor.get_process_memory()
return -1

async def memory_monitor(self) -> None:
"""Track this process's memory usage and act accordingly.
If process memory rises above the spill threshold (70%), start dumping data to
disk until it goes below the target threshold (60%).
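
The manager now keeps a weakref back to its worker rather than receiving the worker as a callback argument, which avoids the circular reference the removed comment warned about. A standalone sketch of the pattern, with Owner and Manager as illustrative names:

import weakref

class Manager:
    def __init__(self, owner: "Owner") -> None:
        # A weak reference does not keep the owner alive, so
        # Owner -> Manager -> Owner never forms a strong cycle.
        self._owner = weakref.ref(owner)

    def poll(self) -> int:
        owner = self._owner()  # resolves to None once the owner is collected
        return 0 if owner is None else owner.memory

class Owner:
    def __init__(self) -> None:
        self.memory = 42
        self.manager = Manager(self)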
@@ -166,16 +176,18 @@ async def memory_monitor(self, worker: Worker) -> None:
# Don't use psutil directly; instead read from the same API that is used
# to send info to the Scheduler (e.g. for the benefit of Active Memory
# Manager) and which can be easily mocked in unit tests.
memory = worker.monitor.get_process_memory()
self._maybe_pause_or_unpause(worker, memory)
await self._maybe_spill(worker, memory)
memory = self.get_process_memory()
self._maybe_pause_or_unpause(memory)
await self._maybe_spill(memory)
finally:
self._memory_monitoring = False

def _maybe_pause_or_unpause(self, worker: Worker, memory: int) -> None:
def _maybe_pause_or_unpause(self, memory: int) -> None:
if self.memory_pause_fraction is False:
return

worker = self._worker()
if not worker:
return
assert self.memory_limit
frac = memory / self.memory_limit
# Pause worker threads if above 80% memory use
@@ -205,7 +217,7 @@ def _maybe_pause_or_unpause(self, worker: Worker, memory: int) -> None:
)
worker.status = Status.running

async def _maybe_spill(self, worker: Worker, memory: int) -> None:
async def _maybe_spill(self, memory: int) -> None:
if self.memory_spill_fraction is False:
return

@@ -257,15 +269,15 @@ async def _maybe_spill(self, worker: Worker, memory: int) -> None:
count += 1
await asyncio.sleep(0)

memory = worker.monitor.get_process_memory()
memory = self.get_process_memory()
if total_spilled > need and memory > target:
# Issue a GC to ensure that the evicted data is actually
# freed from memory and taken into account by the monitor
# before trying to evict even more data.
self._throttled_gc.collect()
memory = worker.monitor.get_process_memory()
memory = self.get_process_memory()

self._maybe_pause_or_unpause(worker, memory)
self._maybe_pause_or_unpause(memory)
if count:
logger.debug(
"Moved %d tasks worth %s to disk",
@@ -287,6 +299,7 @@ class NannyMemoryManager:
memory_limit: int | None
memory_terminate_fraction: float | Literal[False]
memory_monitor_interval: float | None
_nanny: weakref.ReferenceType[Nanny]

def __init__(
self,
@@ -302,32 +315,46 @@ def __init__(
dask.config.get("distributed.worker.memory.monitor-interval"),
default=None,
)
self._nanny = weakref.ref(nanny)
assert isinstance(self.memory_monitor_interval, (int, float))
if self.memory_limit and self.memory_terminate_fraction is not False:
pc = PeriodicCallback(
partial(self.memory_monitor, nanny),
self.memory_monitor,
self.memory_monitor_interval * 1000,
)
nanny.periodic_callbacks["memory_monitor"] = pc

def memory_monitor(self, nanny: Nanny) -> None:
def get_process_memory(self) -> int:
"""Get a measure for process memory.
This can be a mock target.
"""
nanny = self._nanny()
if nanny:
try:
proc = nanny._psutil_process
return proc.memory_info().rss
except (ProcessLookupError, psutil.NoSuchProcess, psutil.AccessDenied):
return -1 # pragma: nocover
return -1

def memory_monitor(self) -> None:
"""Track worker's memory. Restart if it goes above terminate fraction."""
if nanny.status != Status.running:
return # pragma: nocover
if nanny.process is None or nanny.process.process is None:
return # pragma: nocover
process = nanny.process.process
try:
proc = nanny._psutil_process
memory = proc.memory_info().rss
except (ProcessLookupError, psutil.NoSuchProcess, psutil.AccessDenied):
nanny = self._nanny()
if (
not nanny
or nanny.status != Status.running
or nanny.process is None
or nanny.process.process is None
or self.memory_limit is None
):
Comment on lines +348 to +349

Collaborator:
Suggested change: remove "or self.memory_limit is None" from the if clause and add asserts after it:

        ):
            assert self.memory_limit is not None
            assert self.memory_terminate_fraction is not False

Member Author:
A few lines below we compute memory / self.memory_limit, i.e. if memory_limit is None we can skip memory_monitor and return early. So why remove the self.memory_limit is None check from the if clause, and why add these asserts? If the asserts guard a concerning edge case, we should probably add a unit test instead.

Collaborator:
It's not an edge case; the whole method is never scheduled if memory_limit is disabled (see lines ~140 in __init__). There are already several tests for it.

return # pragma: nocover

memory = self.get_process_memory()
if memory / self.memory_limit > self.memory_terminate_fraction:
logger.warning(
"Worker exceeded %d%% memory budget. Restarting",
100 * self.memory_terminate_fraction,
)
process = nanny.process.process
process.terminate()


Expand Down Expand Up @@ -403,4 +430,4 @@ def __get__(self, instance: Nanny | Worker | None, owner):
# This is triggered by Sphinx
return None # pragma: nocover
_warn_deprecated(instance, "memory_monitor")
return partial(instance.memory_manager.memory_monitor, instance)
return instance.memory_manager.memory_monitor
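
Because memory_monitor no longer takes the worker as an argument, the deprecation shim can return the bound method itself instead of a partial. A hedged usage sketch, assuming a live Worker a:

# Deprecated spelling (still works, emits a deprecation warning):
await a.memory_monitor()
# Equivalent direct call on the manager:
await a.memory_manager.memory_monitor()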