From c551dab274b1aa5b377444e8abd51140fff4d3a4 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 24 Oct 2023 01:13:15 +0200 Subject: [PATCH] Revert "Revert protocol changes" This reverts commit 7dc5375e4158d62437edf48ff1661c983c8daf2a. --- distributed/protocol/serialize.py | 15 ++++---- distributed/protocol/tests/test_serialize.py | 36 ++++++++++++++++++-- distributed/protocol/utils.py | 6 ++-- 3 files changed, 44 insertions(+), 13 deletions(-) diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index ee41264a27..f25a2ce3c7 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -680,20 +680,21 @@ def serialize_bytelist( return frames2 -def serialize_bytes(x, **kwargs): +def serialize_bytes(x: object, **kwargs: Any) -> bytes: L = serialize_bytelist(x, **kwargs) return b"".join(L) -def deserialize_bytes(b): +def deserialize_bytes(b: bytes | bytearray | memoryview) -> Any: + """Deserialize the output of :func:`serialize_bytes`""" frames = unpack_frames(b) - header, frames = frames[0], frames[1:] - if header: - header = msgpack.loads(header, raw=False, use_list=False) + bin_header, frames = frames[0], frames[1:] + if bin_header: + header = msgpack.loads(bin_header, raw=False, use_list=False) else: header = {} - frames = decompress(header, frames) - return merge_and_deserialize(header, frames) + frames2 = decompress(header, frames) + return merge_and_deserialize(header, frames2) ################################ diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index 975a1d55c8..8cc85b5db4 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -265,13 +265,11 @@ def test_empty_loads_deep(): assert isinstance(e2[0][0][0], Empty) -@pytest.mark.skipif(np is None, reason="Test needs numpy") @pytest.mark.parametrize("kwargs", [{}, {"serializers": ["pickle"]}]) def test_serialize_bytes(kwargs): for x in [ 1, "abc", - np.arange(5), b"ab" * int(40e6), int(2**26) * b"ab", (int(2**25) * b"ab", int(2**25) * b"ab"), @@ -279,7 +277,39 @@ def test_serialize_bytes(kwargs): b = serialize_bytes(x, **kwargs) assert isinstance(b, bytes) y = deserialize_bytes(b) - assert str(x) == str(y) + assert x == y + + +@pytest.mark.skipif(np is None, reason="Test needs numpy") +@pytest.mark.parametrize("kwargs", [{}, {"serializers": ["pickle"]}]) +def test_serialize_bytes_numpy(kwargs): + x = np.arange(5) + b = serialize_bytes(x, **kwargs) + assert isinstance(b, bytes) + y = deserialize_bytes(b) + assert (x == y).all() + + +@pytest.mark.skipif(np is None, reason="Test needs numpy") +def test_deserialize_bytes_zero_copy_read_only(): + x = np.arange(5) + x.setflags(write=False) + blob = serialize_bytes(x, compression=False) + x2 = deserialize_bytes(blob) + x3 = deserialize_bytes(blob) + addr2 = x2.__array_interface__["data"][0] + addr3 = x3.__array_interface__["data"][0] + assert addr2 == addr3 + + +@pytest.mark.skipif(np is None, reason="Test needs numpy") +def test_deserialize_bytes_zero_copy_writeable(): + x = np.arange(5) + blob = bytearray(serialize_bytes(x, compression=False)) + x2 = deserialize_bytes(blob) + x3 = deserialize_bytes(blob) + x2[0] = 123 + assert x3[0] == 123 @pytest.mark.skipif(np is None, reason="Test needs numpy") diff --git a/distributed/protocol/utils.py b/distributed/protocol/utils.py index e7d4b0f75c..201a7e3da1 100644 --- a/distributed/protocol/utils.py +++ b/distributed/protocol/utils.py @@ -2,7 +2,7 @@ import ctypes import struct -from collections.abc import Sequence +from collections.abc import Collection, Sequence import dask @@ -43,13 +43,13 @@ def frame_split_size( return [frame[i : i + items_per_shard] for i in range(0, nitems, items_per_shard)] -def pack_frames_prelude(frames): +def pack_frames_prelude(frames: Collection[bytes | bytearray | memoryview]) -> bytes: nframes = len(frames) nbytes_frames = map(nbytes, frames) return struct.pack(f"Q{nframes}Q", nframes, *nbytes_frames) -def pack_frames(frames): +def pack_frames(frames: Collection[bytes | bytearray | memoryview]) -> bytes: """Pack frames into a byte-like object This prepends length information to the front of the bytes-like object