Spilling on demand (#756)
Use rapidsai/rmm#892 to implement spilling on demand. Requires [RMM](https://github.com/rapidsai/rmm) and JIT-unspill to be enabled.

The `device_memory_limit` still works as usual -- when known allocations reach `device_memory_limit`, Dask-CUDA starts spilling preemptively. However, with this PR it should be possible to increase `device_memory_limit` significantly, since memory spikes will be handled by spilling on demand.
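
For context, here is a minimal sketch (not part of this PR's diff) of a cluster configured so that spilling on demand can kick in; the pool size and memory limits are illustrative values only:

```python
# Sketch only: a local CUDA cluster where allocations go through an RMM pool
# and JIT-unspill is enabled, so memory spikes can be handled by spilling on
# demand. The sizes below are arbitrary examples.
from dask.distributed import Client
from dask_cuda import LocalCUDACluster

cluster = LocalCUDACluster(
    jit_unspill=True,            # spilling on demand builds on JIT-unspill
    rmm_pool_size="24GB",        # route allocations through an RMM pool
    device_memory_limit="22GB",  # preemptive spilling still starts here, but the
                                 # limit can now be set much higher than before
)
client = Client(cluster)
```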

Closes #755

Authors:
  - Mads R. B. Kristensen (https://github.com/madsbk)

Approvers:
  - Peter Andreas Entschev (https://github.com/pentschev)

URL: #756
madsbk authored Oct 29, 2021
1 parent 1d547a0 commit b52d1d6
Showing 2 changed files with 82 additions and 0 deletions.
61 changes: 61 additions & 0 deletions dask_cuda/proxify_host_file.py
@@ -336,6 +336,27 @@ def maybe_evict_from_device(self, extra_dev_mem=0) -> None:
serialized_proxies.add(id(p))
p._pxy_serialize(serializers=("dask", "pickle"))

def force_evict_from_device(self, nbytes=0) -> int:
    """Spill device buffers until at least ``nbytes`` bytes have been freed.

    Returns the number of bytes actually freed, which may be less than
    ``nbytes`` if there is nothing left to spill.
    """
    freed_memory = 0
    proxies_to_serialize: List[ProxyObject] = []
    with self.lock:
        _, dev_buf_access = self.get_dev_access_info()
        # Spill least-recently-accessed buffers first; among buffers with the
        # same access time, spill larger ones first
        dev_buf_access.sort(key=lambda x: (x[0], -x[1]))
for _, size, proxies in dev_buf_access:
for p in proxies:
proxies_to_serialize.append(p)
freed_memory += size
if freed_memory >= nbytes:
break

serialized_proxies: Set[int] = set()
for p in proxies_to_serialize:
# Avoid trying to serialize the same proxy multiple times
if id(p) not in serialized_proxies:
serialized_proxies.add(id(p))
p._pxy_serialize(serializers=("dask", "pickle"))
return freed_memory

def maybe_evict_from_host(self, extra_host_mem=0) -> None:
if ( # Shortcut when not evicting
self._host.mem_usage() + extra_host_mem <= self._host_memory_limit
@@ -418,6 +439,10 @@ class ProxifyHostFile(MutableMapping):
ProxyObjects, set the `mark_as_explicit_proxies=True` when proxifying with
`proxify_device_objects()`. If ``None``, the "jit-unspill-compatibility-mode"
config value is used, which defaults to False.
spill_on_demand: bool or None, default None
    Enables spilling when the RMM memory pool goes out of memory. If ``None``,
    the "spill-on-demand" config value is used, which defaults to True.
    Notice, enabling this does nothing when RMM isn't available or isn't used.
"""

# Notice, we define the following as static variables because they are used by
@@ -436,6 +461,7 @@ def __init__(
local_directory: str = None,
shared_filesystem: bool = None,
compatibility_mode: bool = None,
spill_on_demand: bool = None,
):
self.store: Dict[Hashable, Any] = {}
self.manager = ProxyManager(device_memory_limit, memory_limit)
Expand All @@ -446,6 +472,10 @@ def __init__(
)
else:
self.compatibility_mode = compatibility_mode
if spill_on_demand is None:
spill_on_demand = dask.config.get("spill-on-demand", default=True)
# `None` in this context means: never initialize
self.spill_on_demand_initialized = False if spill_on_demand else None

# It is a bit hacky to forcefully capture the "distributed.worker" logger,
# eventually it would be better to have a different logger. For now this
@@ -463,6 +493,36 @@ def __iter__(self):
with self.lock:
return iter(self.store)

def initialize_spill_on_demand_once(self):
    """Register callback function to handle RMM out-of-memory exceptions

    This function is idempotent and should be called at least once. Currently,
    we do this in __setitem__ instead of in __init__ because a Dask worker might
    re-initiate the RMM pool and its resource adaptors after creating the
    ProxifyHostFile.
    """
if self.spill_on_demand_initialized is False:
self.spill_on_demand_initialized = True
try:
import rmm.mr

assert hasattr(rmm.mr, "FailureCallbackResourceAdaptor")
except (ImportError, AssertionError):
pass
else:

def oom(nbytes: int) -> bool:
"""Try to handle an out-of-memory error by spilling"""
freed = self.manager.force_evict_from_device(nbytes)
if freed > 0:
return True # Ask RMM to retry the allocation
else:
# Since we didn't find anything to spill, we give up.
return False

current_mr = rmm.mr.get_current_device_resource()
mr = rmm.mr.FailureCallbackResourceAdaptor(current_mr, oom)
rmm.mr.set_current_device_resource(mr)

@property
def fast(self):
"""Dask use this to trigger CPU-to-Disk spilling"""
@@ -478,6 +538,7 @@ def evict():

def __setitem__(self, key, value):
with self.lock:
self.initialize_spill_on_demand_once()
if key in self.store:
# Make sure we register the removal of an existing key
del self[key]
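The new `spill_on_demand` flag can also be driven by Dask config instead of being passed explicitly; a small sketch based on the constructor and the "spill-on-demand" config key from the diff above (the byte limits are arbitrary example values):

```python
# Sketch only: construct a ProxifyHostFile with spilling on demand disabled via
# Dask config. When the config key is unset, spill_on_demand defaults to True.
import dask
from dask_cuda.proxify_host_file import ProxifyHostFile

with dask.config.set({"spill-on-demand": False}):
    dhf = ProxifyHostFile(
        device_memory_limit=10_000_000_000,  # arbitrary example limits
        memory_limit=10_000_000_000,
    )
```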
21 changes: 21 additions & 0 deletions dask_cuda/tests/test_proxify_host_file.py
@@ -17,6 +17,7 @@
from dask_cuda.get_device_memory_objects import get_device_memory_objects
from dask_cuda.proxify_host_file import ProxifyHostFile
from dask_cuda.proxy_object import ProxyObject, asproxy
from dask_cuda.utils import get_device_total_memory

cupy = pytest.importorskip("cupy")
cupy.cuda.set_allocator(None)
@@ -171,6 +172,26 @@ def test_one_item_host_limit():
assert is_proxies_equal(dhf.manager._dev.get_proxies(), [k1])


def test_spill_on_demand():
"""
Test spilling on demand by disabling the device_memory_limit
and allocating two large buffers that will otherwise fail because
of spilling on demand.
"""
rmm = pytest.importorskip("rmm")
if not hasattr(rmm.mr, "FailureCallbackResourceAdaptor"):
pytest.skip("RMM doesn't implement FailureCallbackResourceAdaptor")

total_mem = get_device_total_memory()
dhf = ProxifyHostFile(
device_memory_limit=2 * total_mem,
memory_limit=2 * total_mem,
spill_on_demand=True,
)
for i in range(2):
dhf[i] = rmm.DeviceBuffer(size=total_mem // 2 + 1)


@pytest.mark.parametrize("jit_unspill", [True, False])
def test_local_cuda_cluster(jit_unspill):
"""Testing spilling of a proxied cudf dataframe in a local cuda cluster"""
