diff --git a/distributed/protocol/utils.py b/distributed/protocol/utils.py index 66c9687bae..d5694a1b3f 100644 --- a/distributed/protocol/utils.py +++ b/distributed/protocol/utils.py @@ -23,21 +23,15 @@ def frame_split_size(frame, n=BIG_BYTES_SHARD_SIZE) -> list: >>> frame_split_size([b'12345', b'678'], n=3) # doctest: +SKIP [b'123', b'45', b'678'] """ - if nbytes(frame) <= n: + frame = memoryview(frame) + + if frame.nbytes <= n: return [frame] - if nbytes(frame) > n: - if isinstance(frame, (bytes, bytearray)): - frame = memoryview(frame) - try: - itemsize = frame.itemsize - except AttributeError: - itemsize = 1 + nitems = frame.nbytes // frame.itemsize + items_per_shard = n // frame.itemsize - return [ - frame[i : i + n // itemsize] - for i in range(0, nbytes(frame) // itemsize, n // itemsize) - ] + return [frame[i : i + items_per_shard] for i in range(0, nitems, items_per_shard)] def merge_frames(header, frames):