Skip to content

Commit

Permalink
Annotations and better tests for serialize_bytes (#8300)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored Oct 24, 2023
1 parent 1655465 commit b4cfc0b
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 13 deletions.
15 changes: 8 additions & 7 deletions distributed/protocol/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,20 +680,21 @@ def serialize_bytelist(
return frames2


def serialize_bytes(x: object, **kwargs: Any) -> bytes:
    """Serialize ``x`` into a single ``bytes`` object.

    All keyword arguments are forwarded unchanged to
    :func:`serialize_bytelist`; the frames it returns are concatenated
    into one contiguous buffer.
    """
    frames = serialize_bytelist(x, **kwargs)
    return b"".join(frames)


def deserialize_bytes(b: bytes | bytearray | memoryview) -> Any:
    """Deserialize the output of :func:`serialize_bytes`

    The buffer is split back into frames with ``unpack_frames``. The first
    frame is the msgpack-encoded header and the remaining frames are the
    payload, which is passed through ``decompress`` and then
    ``merge_and_deserialize``.
    """
    frames = unpack_frames(b)
    # Keep the raw header bytes under their own name so that ``header`` only
    # ever refers to the decoded mapping.
    bin_header, frames = frames[0], frames[1:]
    if bin_header:
        header = msgpack.loads(bin_header, raw=False, use_list=False)
    else:
        # An empty first frame means no header was written.
        header = {}
    frames2 = decompress(header, frames)
    return merge_and_deserialize(header, frames2)


################################
Expand Down
36 changes: 33 additions & 3 deletions distributed/protocol/tests/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,21 +265,51 @@ def test_empty_loads_deep():
assert isinstance(e2[0][0][0], Empty)


@pytest.mark.parametrize("kwargs", [{}, {"serializers": ["pickle"]}])
def test_serialize_bytes(kwargs):
    """Round-trip plain Python objects through serialize_bytes/deserialize_bytes."""
    # Only builtin types are used here, so the values can be compared with a
    # plain ``==`` and the test does not need numpy.
    for x in [
        1,
        "abc",
        b"ab" * int(40e6),
        int(2**26) * b"ab",
        (int(2**25) * b"ab", int(2**25) * b"ab"),
    ]:
        b = serialize_bytes(x, **kwargs)
        assert isinstance(b, bytes)
        y = deserialize_bytes(b)
        assert x == y


@pytest.mark.skipif(np is None, reason="Test needs numpy")
@pytest.mark.parametrize("kwargs", [{}, {"serializers": ["pickle"]}])
def test_serialize_bytes_numpy(kwargs):
    """A NumPy array survives a serialize_bytes/deserialize_bytes round trip."""
    original = np.arange(5)
    blob = serialize_bytes(original, **kwargs)
    assert isinstance(blob, bytes)
    restored = deserialize_bytes(blob)
    # Element-wise comparison: every position must match.
    assert (original == restored).all()


@pytest.mark.skipif(np is None, reason="Test needs numpy")
def test_deserialize_bytes_zero_copy_read_only():
    """Two deserializations of one ``bytes`` blob report the same data address."""
    source = np.arange(5)
    source.setflags(write=False)
    blob = serialize_bytes(source, compression=False)

    first = deserialize_bytes(blob)
    second = deserialize_bytes(blob)

    # The first item of __array_interface__["data"] is the buffer address.
    first_addr = first.__array_interface__["data"][0]
    second_addr = second.__array_interface__["data"][0]
    assert first_addr == second_addr


@pytest.mark.skipif(np is None, reason="Test needs numpy")
def test_deserialize_bytes_zero_copy_writeable():
    """A write through one deserialized array is visible through another.

    Both arrays are deserialized from the same mutable ``bytearray``.
    """
    blob = bytearray(serialize_bytes(np.arange(5), compression=False))
    first = deserialize_bytes(blob)
    second = deserialize_bytes(blob)
    first[0] = 123
    assert second[0] == 123


@pytest.mark.skipif(np is None, reason="Test needs numpy")
Expand Down
6 changes: 3 additions & 3 deletions distributed/protocol/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import ctypes
import struct
from collections.abc import Sequence
from collections.abc import Collection, Sequence

import dask

Expand Down Expand Up @@ -43,13 +43,13 @@ def frame_split_size(
return [frame[i : i + items_per_shard] for i in range(0, nitems, items_per_shard)]


def pack_frames_prelude(frames: Collection[bytes | bytearray | memoryview]) -> bytes:
    """Build the binary prelude describing ``frames``.

    The prelude holds the number of frames followed by the size of each
    frame as reported by ``nbytes``. Every value is packed with the
    ``struct`` format code ``Q`` (unsigned long long).
    """
    nframes = len(frames)
    nbytes_frames = map(nbytes, frames)
    # Format is one count field followed by ``nframes`` size fields.
    return struct.pack(f"Q{nframes}Q", nframes, *nbytes_frames)


def pack_frames(frames):
def pack_frames(frames: Collection[bytes | bytearray | memoryview]) -> bytes:
"""Pack frames into a byte-like object
This prepends length information to the front of the bytes-like object
Expand Down

0 comments on commit b4cfc0b

Please sign in to comment.