From 6916076ffc033ec32daba43c8176fc99ed6d3ba2 Mon Sep 17 00:00:00 2001
From: John Kirkham
Date: Thu, 26 Mar 2020 19:44:30 -0700
Subject: [PATCH] Simplify `DeviceSerialized` and usage thereof

As `"dask"` serialization already converts a CUDA object into headers
and frames that Dask is able to work with, drop code that tries to
serialize frames on host further (as they are already as simple as they
can be). Cuts a fair bit of boilerplate from the spilling path, which
should simplify things a bit.
---
 dask_cuda/device_host_file.py | 30 ++++++++----------------------
 1 file changed, 8 insertions(+), 22 deletions(-)

diff --git a/dask_cuda/device_host_file.py b/dask_cuda/device_host_file.py
index 91e4451e..7d1766bd 100644
--- a/dask_cuda/device_host_file.py
+++ b/dask_cuda/device_host_file.py
@@ -30,38 +30,24 @@ class DeviceSerialized:
     that are in host memory
     """
 
-    def __init__(self, header, parts):
+    def __init__(self, header, frames):
         self.header = header
-        self.parts = parts
+        self.frames = frames
 
     def __sizeof__(self):
-        return sum(map(nbytes, self.parts))
+        return sum(map(nbytes, self.frames))
 
 
 @dask_serialize.register(DeviceSerialized)
 def device_serialize(obj):
-    headers = []
-    all_frames = []
-    for part in obj.parts:
-        header, frames = serialize(part)
-        header["frame-start-stop"] = [len(all_frames), len(all_frames) + len(frames)]
-        headers.append(header)
-        all_frames.extend(frames)
-
-    header = {"sub-headers": headers, "main-header": obj.header}
-
-    return header, all_frames
+    header = {"main-header": dict(obj.header)}
+    frames = list(obj.frames)
+    return header, frames
 
 
 @dask_deserialize.register(DeviceSerialized)
 def device_deserialize(header, frames):
-    parts = []
-    for sub_header in header["sub-headers"]:
-        start, stop = sub_header.pop("frame-start-stop")
-        part = deserialize(sub_header, frames[start:stop])
-        parts.append(part)
-
-    return DeviceSerialized(header["main-header"], parts)
+    return DeviceSerialized(header["main-header"], frames)
 
 
 def device_to_host(obj: object) -> DeviceSerialized:
@@ -71,7 +57,7 @@ def device_to_host(obj: object) -> DeviceSerialized:
 
 
 def host_to_device(s: DeviceSerialized) -> object:
-    return deserialize(s.header, s.parts)
+    return deserialize(s.header, s.frames)
 
 
 class DeviceHostFile(ZictBase):