Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🛑 DNM Deserialization: zero-copy merge subframes when possible #5112

Closed
17 changes: 16 additions & 1 deletion distributed/protocol/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,23 @@ def _decode_default(obj):
if deserialize:
if "compression" in sub_header:
sub_frames = decompress(sub_header, sub_frames)

# Check for memoryviews in preceding frames that share an underlying
# buffer with these sub-frames, to figure out what offset in that buffer
# `sub_frames` starts at.
memoryview_offset = 0
if sub_frames and isinstance(sub_frames[0], memoryview):
obj = sub_frames[0].obj
for f in reversed(frames[:offset]):
if not (isinstance(f, memoryview) and f.obj is obj):
gjoseph92 marked this conversation as resolved.
Show resolved Hide resolved
break
memoryview_offset += len(f)

return merge_and_deserialize(
sub_header, sub_frames, deserializers=deserializers
sub_header,
sub_frames,
deserializers=deserializers,
memoryview_offset=memoryview_offset,
)
else:
return Serialized(sub_header, sub_frames)
Expand Down
55 changes: 50 additions & 5 deletions distributed/protocol/serialize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import importlib
import itertools
import traceback
from array import array
from enum import Enum
Expand Down Expand Up @@ -450,30 +451,74 @@ def serialize_and_split(
return header, out_frames


def merge_and_deserialize(header, frames, deserializers=None, memoryview_offset=0):
    """Merge and deserialize frames

    This function is a replacement for `deserialize()` that merges
    frames that were split by `serialize_and_split()`.

    When ``frames`` contains memoryviews that share an underlying buffer,
    ``memoryview_offset`` must be the index into that underlying buffer
    where ``frames`` start (in bytes, not frame counts).

    Parameters
    ----------
    header
        Serialization header. When it contains ``"split-num-sub-frames"``,
        that entry and ``"split-offsets"`` are read pairwise to regroup
        ``frames``; otherwise ``frames`` is passed through unchanged.
    frames
        Sequence of frames to merge and hand to ``deserialize()``.
    deserializers
        Forwarded unchanged to ``deserialize()``.
    memoryview_offset
        Byte position of ``frames[0]`` in the shared underlying buffer.

    Returns
    -------
    Whatever ``deserialize()`` returns for the merged frames.

    See Also
    --------
    deserialize
    serialize_and_split
    merge_subframes
    """
    if "split-num-sub-frames" not in header:
        merged_frames = frames
    else:
        # byte_starts[i] is the combined length of frames[:i], i.e. where
        # frames[i] begins relative to frames[0].
        byte_starts = [0]
        byte_starts.extend(itertools.accumulate(map(len, frames)))

        merged_frames = []
        for n, offset in zip(header["split-num-sub-frames"], header["split-offsets"]):
            if n == 1:
                merged = frames[offset]
            else:
                # The group's position in the shared buffer is the caller's
                # base offset plus the bytes of every frame preceding it.
                # The base offset must NOT also be advanced per group, or the
                # preceding bytes would be counted twice.
                merged = merge_subframes(
                    frames[offset : offset + n],
                    memoryview_offset=memoryview_offset + byte_starts[offset],
                )
            merged_frames.append(merged)

    return deserialize(header, merged_frames, deserializers=deserializers)


def merge_subframes(
    subframes: "list[memoryview | bytearray | bytes]", memoryview_offset: int = 0
) -> "memoryview | bytearray":
    """Merge a list of frames into one buffer.

    If all frames are contiguous memoryviews backed by the same underlying
    buffer, this is zero-copy: a byte view of that buffer is sliced and
    returned. ``memoryview_offset`` must be the index into that underlying
    buffer where the subframes start (in bytes, not frame counts), such that
    ``memoryview(subframes[0].obj)[memoryview_offset] == subframes[0][0]``.

    Otherwise, all frames are copied into a new contiguous bytearray. The
    copy path is also taken when the zero-copy slice cannot be trusted:
    a negative ``memoryview_offset``, an underlying object that cannot be
    viewed as bytes, or a slice that does not span exactly the combined
    byte size of the subframes.

    Parameters
    ----------
    subframes
        Frames to merge, in order.
    memoryview_offset
        Byte position of ``subframes[0]`` in the shared underlying buffer.
        Only used on the zero-copy path.

    Returns
    -------
    A ``memoryview`` over the shared buffer (zero-copy path), or a new
    ``bytearray`` holding the concatenated frames (copy path).

    See Also
    --------
    merge_and_deserialize
    """
    if subframes and memoryview_offset >= 0:
        first = subframes[0]
        # A memoryview whose ``obj`` is None has no exporter to re-wrap.
        if isinstance(first, memoryview) and first.obj is not None:
            obj = first.obj
            # Every frame (not only the first) must be a contiguous view of
            # the very same exporter for the byte arithmetic below to hold.
            same_buffer = all(
                isinstance(f, memoryview) and f.obj is obj and f.contiguous
                for f in subframes
            )
            if same_buffer:
                nbytes = sum(f.nbytes for f in subframes)
                try:
                    # Index the exporter in bytes regardless of its item size.
                    base = memoryview(obj).cast("B")
                except (TypeError, BufferError):
                    base = None
                if base is not None:
                    merged = base[memoryview_offset : memoryview_offset + nbytes]
                    # Slicing past the end truncates silently; only accept a
                    # slice that covers every expected byte.
                    if merged.nbytes == nbytes:
                        return merged

    return bytearray().join(subframes)


class Serialize:
"""Mark an object that should be serialized

Expand Down
4 changes: 3 additions & 1 deletion distributed/protocol/tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,14 @@ def test_dumps_serialize_numpy_large():
frames = dumps([to_serialize(x)])
dtype, shape = x.dtype, x.shape
checksum = crc32(x)
del x
[y] = loads(frames)

assert (y.dtype, y.shape) == (dtype, shape)
assert crc32(y) == checksum, "Arrays are unequal"

x[:] = 2 # shared buffer; serialization is zero-copy
assert (x == y).all(), "Data was copied"


@pytest.mark.parametrize(
"dt,size",
Expand Down