Skip to content

Commit

Permalink
Changed mind again: back to AssertionError
Browse files Browse the repository at this point in the history
As noted in ab6119a: I think this error is currently impossible to raise in real use, and we want to keep it that way. We'd like a future test to fail if something causes this case to happen, rather than a silent copy.
  • Loading branch information
gjoseph92 committed Jul 31, 2021
1 parent ab6119a commit e2da610
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 53 deletions.
12 changes: 5 additions & 7 deletions distributed/protocol/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
msgpack_encode_default,
serialize_and_split,
)
from .utils import copy_frames, msgpack_opts
from .utils import msgpack_opts

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -115,12 +115,10 @@ def _decode_default(obj):
if not isinstance(f, memoryview) or f.obj is not obj:
# Walking backwards from the start of `sub_frames`, reached a frame that doesn't
# belong to the same memoryview
if memoryview_offset != 0:
# If given an initial offset for `frames[0]`, but we're working with a different
# buffer, something is wrong. We don't know if there's an initial offset for this buffer too,
# so to be safe, copy all sub-frames to new memory to prevent faulty zero-copy deserialization.
sub_frames = copy_frames(sub_frames)
subframe_memoryview_offset = 0
assert memoryview_offset == 0, (
f"Given an initial offset of {memoryview_offset} into the frames' underlying buffer "
"but the frames are backed by multiple buffers. This should not happen."
)
break
subframe_memoryview_offset += len(f)

Expand Down
11 changes: 5 additions & 6 deletions distributed/protocol/tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,13 @@ def test_maybe_compress_memoryviews():
assert len(payload) < x.nbytes / 10


def test_loads_multiple_memoryviews():
def test_loads_multiple_memoryviews_error():
"""Test that zero-copy deserialization isn't attempted when the frames come from different
buffers _and_ an overall offset is passed into `loads` (because the buffer behind the header frame
has some extra bytes at the beginning).
This case is highly contrived and should not occur in real-world use, but since we have logic to defend
against it, we want to test that it works.
This case is highly contrived and should not occur in real-world use, but since we have an assert to catch
it, we want to test that it works.
"""
msg = [Serialize(b"ab" * int(2 ** 26))]

Expand All @@ -247,6 +247,5 @@ def test_loads_multiple_memoryviews():
# If we naively used this 3-byte prelude offset from the new header on all the memoryviews that actually
# needed at 40-byte offset, we'd deserialize incorrectly. Instead, `loads` detects this broken case and
# defensively copies the frames to a new buffer.
result = loads(frames, memoryview_offset=len(prelude))
correct = result == loads(dumps(msg))
assert correct # Prevents slow pytest diffs if assert fails
with pytest.raises(AssertionError, match="backed by multiple buffers"):
loads(frames, memoryview_offset=len(prelude))
27 changes: 1 addition & 26 deletions distributed/protocol/tests/test_protocol_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import pytest

from distributed.protocol.utils import copy_frames, pack_frames, unpack_frames
from distributed.protocol.utils import pack_frames, unpack_frames


def test_pack_frames():
Expand All @@ -11,26 +9,3 @@ def test_pack_frames():

assert frames == frames2
assert prelude_size == len(b) - sum(len(x) for x in frames)


@pytest.mark.parametrize(
"frames",
[
[],
[b"123"],
[b"123", b"asdf"],
[memoryview(b"abcd")[1:], b"x", memoryview(b"12345678")],
],
)
def test_copy_frames(frames):
new = copy_frames(frames)
assert [bytes(f) for f in new] == [bytes(x) for x in frames]
assert all(isinstance(f, memoryview) for f in new)
if frames:
new_buffers = set(f.obj for f in new)
assert len(new_buffers) == 1
if len(frames) != 1 and not isinstance(frames[0], bytes):
# `b"".join([b"123"])` is zero-copy. We are okay allowing this optimization.
assert not new_buffers.intersection(
f.obj if isinstance(f, memoryview) else f for f in frames
)
14 changes: 0 additions & 14 deletions distributed/protocol/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,3 @@ def unpack_frames(b):
start = end

return prelude_size, frames


def copy_frames(frames):
"""Copy frames into new contiguous memory.
Returns a duplicate frames list of memoryviews referencing the new memory.
"""
buffer = memoryview(b"".join(frames))
i = 0
new_frames = []
for frame in frames:
new_frames.append(buffer[i : i + len(frame)])
i += len(frame)
return new_frames

0 comments on commit e2da610

Please sign in to comment.