Skip to content

Commit

Permalink
Merge pull request #256 from jakirkham/use_dask_serialization_spill_to_host
Browse files Browse the repository at this point in the history

Use `"dask"` serialization to move to/from host
  • Loading branch information
pentschev authored Mar 26, 2020
2 parents a51b127 + 3f1577c commit 2dab1b6
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 39 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Respect `temporary-directory` config for spilling (#247) `John Kirkham`_
- Relax CuPy pin (#248) `John Kirkham`_
- Added `ignore_index` argument to `partition_by_hash()` (#253) `Mads R. B. Kristensen`_
- Use `"dask"` serialization to move to/from host (#256) `John Kirkham`_
- Drop Numba `DeviceNDArray` code for `sizeof` (#257) `John Kirkham`_
- Support spilling of device objects in dictionaries (#260) `Mads R. B. Kristensen`_

Expand Down
50 changes: 14 additions & 36 deletions dask_cuda/device_host_file.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

import numpy

import dask
from distributed.protocol import (
dask_deserialize,
Expand All @@ -12,40 +14,32 @@
from distributed.utils import nbytes
from distributed.worker import weight

from numba import cuda
from zict import Buffer, File, Func
from zict.common import ZictBase

from .is_device_object import is_device_object

try:
from rmm import DeviceBuffer as cuda_memory_manager
except ImportError:
import numba.cuda as cuda_memory_manager


class DeviceSerialized:
""" Store device object on the host
This stores a device-side object as
1. A msgpack encodable header
2. A list of objects that are returned by calling
`numba.cuda.as_cuda_array(f).copy_to_host()`
which are typically NumPy arrays
2. A list of `bytes`-like objects (like NumPy arrays)
that are in host memory
"""

def __init__(self, header, parts, is_cuda):
def __init__(self, header, parts):
self.header = header
self.parts = parts
self.is_cuda = is_cuda

def __sizeof__(self):
return sum(map(nbytes, self.parts))


@dask_serialize.register(DeviceSerialized)
def _(obj):
def device_serialize(obj):
headers = []
all_frames = []
for part in obj.parts:
Expand All @@ -54,43 +48,30 @@ def _(obj):
headers.append(header)
all_frames.extend(frames)

header = {"sub-headers": headers, "is-cuda": obj.is_cuda, "main-header": obj.header}
header = {"sub-headers": headers, "main-header": obj.header}

return header, all_frames


@dask_deserialize.register(DeviceSerialized)
def _(header, frames):
def device_deserialize(header, frames):
parts = []
for sub_header in header["sub-headers"]:
start, stop = sub_header.pop("frame-start-stop")
part = deserialize(sub_header, frames[start:stop])
parts.append(part)

return DeviceSerialized(header["main-header"], parts, header["is-cuda"])


def copy_to_host(ary):
if hasattr(ary, "copy_to_host"):
return ary.copy_to_host()
else:
return cuda.as_cuda_array(ary).copy_to_host()
return DeviceSerialized(header["main-header"], parts)


def device_to_host(obj: object) -> DeviceSerialized:
header, frames = serialize(obj, serializers=["cuda", "pickle"])
is_cuda = [hasattr(f, "__cuda_array_interface__") for f in frames]
frames = [copy_to_host(f) if ic else f for ic, f in zip(is_cuda, frames)]
return DeviceSerialized(header, frames, is_cuda)
header, frames = serialize(obj, serializers=["dask", "pickle"])
frames = [numpy.asarray(f) for f in frames]
return DeviceSerialized(header, frames)


def host_to_device(s: DeviceSerialized) -> object:
frames = [
cuda_memory_manager.to_device(f.ravel().view("u1")) if ic else f
for ic, f in zip(s.is_cuda, s.parts)
]

return deserialize(s.header, frames)
return deserialize(s.header, s.parts)


class DeviceHostFile(ZictBase):
Expand All @@ -115,10 +96,7 @@ class DeviceHostFile(ZictBase):
"""

def __init__(
self,
device_memory_limit=None,
memory_limit=None,
local_directory=None,
self, device_memory_limit=None, memory_limit=None, local_directory=None,
):
if local_directory is None:
local_directory = dask.config.get("temporary-directory") or os.getcwd()
Expand Down
4 changes: 2 additions & 2 deletions dask_cuda/tests/test_device_host_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,9 @@ def test_serialize_cupy_collection(collection, length, value):
if length > 5:
assert obj.header["serializer"] == "pickle"
elif length > 0:
assert all([h["serializer"] == "cuda" for h in obj.header["sub-headers"]])
assert all([h["serializer"] == "dask" for h in obj.header["sub-headers"]])
else:
assert obj.header["serializer"] == "cuda"
assert obj.header["serializer"] == "dask"

btslst = serialize_bytelist(obj)

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
dask>=2.9.0
distributed>=2.7.0
distributed>=2.11.0
pynvml>=8.0.3
numpy>=1.16.0
numba>=0.40.1

0 comments on commit 2dab1b6

Please sign in to comment.