diff --git a/dask_cuda/device_host_file.py b/dask_cuda/device_host_file.py index 91e4451e4..7d1766bd7 100644 --- a/dask_cuda/device_host_file.py +++ b/dask_cuda/device_host_file.py @@ -30,38 +30,24 @@ class DeviceSerialized: that are in host memory """ - def __init__(self, header, parts): + def __init__(self, header, frames): self.header = header - self.parts = parts + self.frames = frames def __sizeof__(self): - return sum(map(nbytes, self.parts)) + return sum(map(nbytes, self.frames)) @dask_serialize.register(DeviceSerialized) def device_serialize(obj): - headers = [] - all_frames = [] - for part in obj.parts: - header, frames = serialize(part) - header["frame-start-stop"] = [len(all_frames), len(all_frames) + len(frames)] - headers.append(header) - all_frames.extend(frames) - - header = {"sub-headers": headers, "main-header": obj.header} - - return header, all_frames + header = {"main-header": dict(obj.header)} + frames = list(obj.frames) + return header, frames @dask_deserialize.register(DeviceSerialized) def device_deserialize(header, frames): - parts = [] - for sub_header in header["sub-headers"]: - start, stop = sub_header.pop("frame-start-stop") - part = deserialize(sub_header, frames[start:stop]) - parts.append(part) - - return DeviceSerialized(header["main-header"], parts) + return DeviceSerialized(header["main-header"], frames) def device_to_host(obj: object) -> DeviceSerialized: @@ -71,7 +57,7 @@ def device_to_host(obj: object) -> DeviceSerialized: def host_to_device(s: DeviceSerialized) -> object: - return deserialize(s.header, s.parts) + return deserialize(s.header, s.frames) class DeviceHostFile(ZictBase):