Changed mind: multiple-memoryview case with copy
A separate `copy_frames` function makes this more readable and easier to test. Also, I came up with a test for this case that's still contrived, but not ridiculously contrived.

That said, we don't want this copy to happen. And I'm pretty confident it will never happen with real comms, because either the whole message is one buffer (TCP), or memoryviews aren't used at all. This mix-and-match only ever happens in tests; see 1869b18. So maybe we should stick with the assert as a warning to future developers, so nobody breaks this in a way that still works but hides a silent performance regression?
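
To make the concern concrete, here's a minimal standalone sketch (plain Python with made-up buffers, not distributed's actual wire format) of why reusing one buffer's prelude offset on frames backed by a different buffer misreads data; that is exactly what the defensive copy guards against:

```python
# Standalone sketch of the failure mode (hypothetical buffers, not the real
# protocol): frames are normally zero-copy slices of one contiguous buffer,
# located by a single prelude offset.
prelude = b"0" * 40
payload = b"headerbytes" + b"frame-one" + b"frame-two"
buf = memoryview(prelude + payload)
offset = len(prelude)
assert bytes(buf[offset : offset + 11]) == b"headerbytes"  # zero-copy read works

# If the header frame gets rebuilt in a *different* buffer with a *different*
# prelude, its 3-byte offset is meaningless for the frames still backed by the
# original 40-byte-prelude buffer:
new_header = memoryview(b"xxx" + b"headerbytes")[3:]
assert new_header.obj is not buf.obj             # two backing buffers now
assert bytes(buf[3 : 3 + 11]) != b"headerbytes"  # misaligned zero-copy read
# The safe fallback is to copy every frame into one fresh contiguous buffer
# (what `copy_frames` below does) and reset the offset to 0.
```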
gjoseph92 committed Jul 31, 2021
1 parent 6ace0a4 commit ab6119a
Showing 4 changed files with 68 additions and 21 deletions.
12 changes: 2 additions & 10 deletions distributed/protocol/core.py
@@ -11,7 +11,7 @@
msgpack_encode_default,
serialize_and_split,
)
from .utils import msgpack_opts
from .utils import copy_frames, msgpack_opts

logger = logging.getLogger(__name__)

@@ -119,16 +119,8 @@ def _decode_default(obj):
# If given an initial offset for `frames[0]`, but we're working with a different
# buffer, something is wrong. We don't know if there's an initial offset for this buffer too,
# so to be safe, copy all sub-frames to new memory to prevent faulty zero-copy deserialization.
subframe_buffer = memoryview(b"".join(sub_frames))
i = 0
new_subframes = []
for frame in sub_frames:
new_subframes.append(
subframe_buffer[i : i + len(frame)]
)
i += len(frame)
sub_frames = copy_frames(sub_frames)
subframe_memoryview_offset = 0

break
subframe_memoryview_offset += len(f)

36 changes: 26 additions & 10 deletions distributed/protocol/tests/test_protocol.py
@@ -3,6 +3,7 @@
from distributed.protocol import dumps, loads, maybe_compress, msgpack, to_serialize
from distributed.protocol.compression import compressions
from distributed.protocol.serialize import Serialize, Serialized, deserialize, serialize
from distributed.protocol.utils import pack_frames, unpack_frames
from distributed.system import MEMORY_LIMIT
from distributed.utils import nbytes

@@ -220,17 +221,32 @@ def test_maybe_compress_memoryviews():
assert len(payload) < x.nbytes / 10


def test_loads_multiple_memoryviews_error():
msg = [Serialize(b"abcd"), b"wxyz"]
frames = dumps(msg)
header = frames[0]
assert not isinstance(header, memoryview)
assert any(isinstance(f, memoryview) for f in frames)
def test_loads_multiple_memoryviews():
"""Test that zero-copy deserialization isn't attempted when the frames come from different
buffers _and_ an overall offset is passed into `loads` (because the buffer behind the header frame
has some extra bytes at the beginning).
This case is highly contrived and should not occur in real-world use, but since we have logic to defend
against it, we want to test that it works.
"""
msg = [Serialize(b"ab" * int(2 ** 26))]

prelude = b"foobar"
# Pack/unpack the frames so they're all backed by the same bytestring buffer,
# which also happens to have a prelude (40 bytes currently).
real_prelude_len, frames = unpack_frames(pack_frames(dumps(msg)))
assert all(isinstance(f, memoryview) and f.obj is frames[0].obj for f in frames)

# Mess with the header frame, putting it in a different buffer with a different prelude length.
header = frames[0]
prelude = b"xxx"
assert len(prelude) != real_prelude_len
new_header = memoryview(prelude + header)[len(prelude) :]
assert new_header.obj is not header
assert new_header.obj is not header.obj
frames[0] = new_header

with pytest.raises(AssertionError, match="backed by multiple buffers"):
loads(frames, memoryview_offset=len(prelude))
# If we naively used this 3-byte prelude offset from the new header on all the memoryviews that actually
# need a 40-byte offset, we'd deserialize incorrectly. Instead, `loads` detects this broken case and
# defensively copies the frames to a new buffer.
result = loads(frames, memoryview_offset=len(prelude))
correct = result == loads(dumps(msg))
assert correct # Prevents slow pytest diffs if assert fails
27 changes: 26 additions & 1 deletion distributed/protocol/tests/test_protocol_utils.py
@@ -1,4 +1,6 @@
from distributed.protocol.utils import pack_frames, unpack_frames
import pytest

from distributed.protocol.utils import copy_frames, pack_frames, unpack_frames


def test_pack_frames():
@@ -9,3 +11,26 @@ def test_pack_frames():

assert frames == frames2
assert prelude_size == len(b) - sum(len(x) for x in frames)


@pytest.mark.parametrize(
"frames",
[
[],
[b"123"],
[b"123", b"asdf"],
[memoryview(b"abcd")[1:], b"x", memoryview(b"12345678")],
],
)
def test_copy_frames(frames):
new = copy_frames(frames)
assert [bytes(f) for f in new] == [bytes(x) for x in frames]
assert all(isinstance(f, memoryview) for f in new)
if frames:
new_buffers = set(f.obj for f in new)
assert len(new_buffers) == 1
if len(frames) != 1 and not isinstance(frames[0], bytes):
# `b"".join([b"123"])` is zero-copy. We are okay allowing this optimization.
assert not new_buffers.intersection(
f.obj if isinstance(f, memoryview) else f for f in frames
)
14 changes: 14 additions & 0 deletions distributed/protocol/utils.py
@@ -83,3 +83,17 @@ def unpack_frames(b):
start = end

return prelude_size, frames


def copy_frames(frames):
"""Copy frames into new contiguous memory.
Returns a duplicate list of the frames, as memoryviews referencing the new memory.
"""
buffer = memoryview(b"".join(frames))
i = 0
new_frames = []
for frame in frames:
new_frames.append(buffer[i : i + len(frame)])
i += len(frame)
return new_frames
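
For reference, a quick usage sketch of the new helper (inputs are made up for illustration), showing the properties the tests above rely on: contents are preserved byte for byte, and every returned memoryview is a slice of a single fresh backing buffer that starts at offset 0:

```python
from distributed.protocol.utils import copy_frames

# Hypothetical inputs: frames backed by different buffers.
mixed = [memoryview(b"abcd")[1:], b"xyz"]
copied = copy_frames(mixed)

assert [bytes(c) for c in copied] == [b"bcd", b"xyz"]       # contents preserved
assert all(isinstance(c, memoryview) for c in copied)
assert all(c.obj is copied[0].obj for c in copied)          # one new backing buffer
assert copied[0].obj is not mixed[0].obj                    # ...distinct from the inputs
assert copied[0].obj[: len(copied[0])] == bytes(copied[0])  # first frame starts at offset 0
```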
