Skip to content

Commit

Permalink
[v3] Buffer ensure correct subclass based on the BufferPrototype
Browse files Browse the repository at this point in the history
…argument (zarr-developers#1974)

* impl. and use Buffer.from_buffer()

* Update src/zarr/buffer.py

Co-authored-by: Davis Bennett <[email protected]>

* Apply suggestions from code review

Co-authored-by: Davis Bennett <[email protected]>

---------

Co-authored-by: Davis Bennett <[email protected]>
  • Loading branch information
madsbk and d-v-b authored Jun 25, 2024
1 parent e3ee09e commit c677da4
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 7 deletions.
25 changes: 24 additions & 1 deletion src/zarr/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def create_zero_length(cls) -> Self:

@classmethod
def from_array_like(cls, array_like: ArrayLike) -> Self:
"""Create a new buffer of a array-like object
"""Create a new buffer of an array-like object
Parameters
----------
Expand All @@ -159,6 +159,29 @@ def from_array_like(cls, array_like: ArrayLike) -> Self:
"""
return cls(array_like)

@classmethod
def from_buffer(cls, buffer: Buffer) -> Self:
"""Create a new buffer of an existing Buffer
This is useful if you want to ensure that an existing buffer is
of the correct subclass of Buffer. E.g., MemoryStore uses this
to return a buffer instance of the subclass specified by its
BufferPrototype argument.
Typically, this only copies data if the data has to be moved between
memory types, such as from host to device memory.
Parameters
----------
buffer
buffer object.
Returns
-------
A new buffer representing the content of the input buffer
"""
return cls.from_array_like(buffer.as_array_like())

@classmethod
def from_bytes(cls, bytes_like: BytesLike) -> Self:
"""Create a new buffer of a bytes-like object (host memory)
Expand Down
2 changes: 1 addition & 1 deletion src/zarr/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async def get(
try:
value = self._store_dict[key]
start, length = _normalize_interval_index(value, byte_range)
return value[start : start + length]
return prototype.buffer.from_buffer(value[start : start + length])
except KeyError:
return None

Expand Down
6 changes: 3 additions & 3 deletions src/zarr/store/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import fsspec

from zarr.abc.store import Store
from zarr.buffer import Buffer, BufferPrototype, default_buffer_prototype
from zarr.buffer import Buffer, BufferPrototype
from zarr.common import OpenMode
from zarr.store.core import _dereference_path

Expand Down Expand Up @@ -84,7 +84,7 @@ def __repr__(self) -> str:
async def get(
self,
key: str,
prototype: BufferPrototype = default_buffer_prototype,
prototype: BufferPrototype,
byte_range: tuple[int | None, int | None] | None = None,
) -> Buffer | None:
path = _dereference_path(self.path, key)
Expand All @@ -99,7 +99,7 @@ async def get(
end = length
else:
end = None
value: Buffer = prototype.buffer.from_bytes(
value = prototype.buffer.from_bytes(
await (
self._fs._cat_file(path, start=byte_range[0], end=end)
if byte_range
Expand Down
5 changes: 4 additions & 1 deletion tests/v3/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ async def get(
) -> Buffer | None:
if "json" not in key:
assert prototype.buffer is MyBuffer
return await super().get(key, byte_range)
ret = await super().get(key=key, prototype=prototype, byte_range=byte_range)
if ret is not None:
assert isinstance(ret, prototype.buffer)
return ret


def test_nd_array_like(xp):
Expand Down
2 changes: 1 addition & 1 deletion tests/v3/test_store/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async def test_basic():
data = b"hello"
await store.set("foo", Buffer.from_bytes(data))
assert await store.exists("foo")
assert (await store.get("foo")).to_bytes() == data
assert (await store.get("foo", prototype=default_buffer_prototype)).to_bytes() == data
out = await store.get_partial_values(
prototype=default_buffer_prototype, key_ranges=[("foo", (1, None))]
)
Expand Down

0 comments on commit c677da4

Please sign in to comment.