ArrowRecordBatchCodec and vlen string support #2031
@@ -0,0 +1,81 @@

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

import pyarrow as pa

from zarr.abc.codec import ArrayBytesCodec
from zarr.array_spec import ArraySpec
from zarr.buffer import Buffer, NDBuffer
from zarr.codecs.registry import register_codec
from zarr.common import JSON, parse_named_configuration

if TYPE_CHECKING:
    from typing_extensions import Self

CHUNK_FIELD_NAME = "zarr_chunk"


@dataclass(frozen=True)
class ArrowRecordBatchCodec(ArrayBytesCodec):
    def __init__(self) -> None:
        pass

    @classmethod
    def from_dict(cls, data: dict[str, JSON]) -> Self:
        _, configuration_parsed = parse_named_configuration(
            data, "arrow", require_configuration=False
        )
        configuration_parsed = configuration_parsed or {}
        return cls(**configuration_parsed)

    def to_dict(self) -> dict[str, JSON]:
        return {"name": "arrow"}

    def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
        return self

    async def _decode_single(
        self,
        chunk_bytes: Buffer,
        chunk_spec: ArraySpec,
    ) -> NDBuffer:
        assert isinstance(chunk_bytes, Buffer)

        # TODO: make this compatible with buffer prototype
        arrow_buffer = memoryview(chunk_bytes.to_bytes())
        with pa.ipc.open_stream(arrow_buffer) as reader:
            batches = [b for b in reader]
        assert len(batches) == 1
        arrow_array = batches[0][CHUNK_FIELD_NAME]
        chunk_array = chunk_spec.prototype.nd_buffer.from_ndarray_like(
            arrow_array.to_numpy(zero_copy_only=False)
        )

        # ensure correct chunk shape
        if chunk_array.shape != chunk_spec.shape:
            chunk_array = chunk_array.reshape(
                chunk_spec.shape,
            )
        return chunk_array

    async def _encode_single(
        self,
        chunk_array: NDBuffer,
        chunk_spec: ArraySpec,
    ) -> Buffer | None:
        assert isinstance(chunk_array, NDBuffer)
        arrow_array = pa.array(chunk_array.as_ndarray_like().ravel())
```
Review comment on the `pa.array(...)` line above: "Should probably be […] In theory, it would be possible to do zero-copy transfers for CuPy arrays too but would need to go from CuPy -> Numba first and then Numba -> Arrow."
```python
        rb = pa.record_batch([arrow_array], names=[CHUNK_FIELD_NAME])
        # TODO: allocate buffer differently
        sink = pa.BufferOutputStream()
        with pa.ipc.new_stream(sink, rb.schema) as writer:
            writer.write_batch(rb)
        return chunk_spec.prototype.buffer.from_bytes(memoryview(sink.getvalue()))

    def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int:
        raise ValueError("Don't know how to compute encoded size!")


register_codec("arrow", ArrowRecordBatchCodec)
```
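For context, here is a minimal, self-contained sketch of the Arrow IPC round trip that `_encode_single` and `_decode_single` perform. It uses only `numpy` and `pyarrow`; the shape and values are illustrative, and the `zarr_chunk` column name mirrors `CHUNK_FIELD_NAME` above.

```python
import numpy as np
import pyarrow as pa

# Encode: wrap the flattened chunk in a single-column record batch
# and serialize it with the Arrow IPC stream format.
chunk = np.arange(12, dtype="int32").reshape(3, 4)
arrow_array = pa.array(chunk.ravel())
rb = pa.record_batch([arrow_array], names=["zarr_chunk"])
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, rb.schema) as writer:
    writer.write_batch(rb)
encoded = sink.getvalue()  # pa.Buffer holding the IPC stream

# Decode: read the single batch back and restore the chunk shape.
with pa.ipc.open_stream(memoryview(encoded)) as reader:
    batches = list(reader)
decoded = batches[0]["zarr_chunk"].to_numpy(zero_copy_only=False).reshape(3, 4)
assert np.array_equal(chunk, decoded)
```

The IPC stream carries schema and framing metadata on top of the column data, which is one reason `compute_encoded_size` raises: for variable-length types in particular, the encoded size cannot be derived from the input byte length alone.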
@@ -0,0 +1,57 @@
```python
import numpy as np
import pytest

from zarr.abc.store import Store
from zarr.array import Array
from zarr.codecs import ArrowRecordBatchCodec
from zarr.store.core import StorePath


@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"])
@pytest.mark.parametrize(
    "dtype",
    [
        "uint8",
        "uint16",
        "uint32",
        "uint64",
        "int8",
        "int16",
        "int32",
        "int64",
        "float32",
        "float64",
    ],
)
def test_arrow_standard_dtypes(store: Store, dtype) -> None:
    data = np.arange(0, 256, dtype=dtype).reshape((16, 16))

    a = Array.create(
        StorePath(store, path="arrow"),
        shape=data.shape,
        chunk_shape=(16, 16),
        dtype=data.dtype,
        fill_value=0,
        codecs=[ArrowRecordBatchCodec()],
    )

    a[:, :] = data
    assert np.array_equal(data, a[:, :])


@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"])
def test_arrow_vlen_string(store: Store) -> None:
    strings = ["hello", "world", "this", "is", "a", "test"]
    data = np.array(strings).reshape((2, 3))

    a = Array.create(
        StorePath(store, path="arrow"),
        shape=data.shape,
        chunk_shape=data.shape,
        dtype=data.dtype,
        fill_value=0,
        codecs=[ArrowRecordBatchCodec()],
    )

    a[:, :] = data
    assert np.array_equal(data, a[:, :])
```
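The vlen-string test is the interesting case: NumPy stores these strings as fixed-width unicode, while Arrow serializes them as a true variable-length string column. A minimal sketch of that conversion, independent of zarr (variable names are illustrative):

```python
import numpy as np
import pyarrow as pa

# NumPy allocates fixed-width unicode ("<U5" here); Arrow converts this
# to a variable-length string column, which is what the codec serializes.
data = np.array(["hello", "world", "this", "is", "a", "test"])
arr = pa.array(data)
print(arr.type)  # string (variable-length)
print(arr.to_numpy(zero_copy_only=False))  # back as a NumPy object array
```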
Review comment: TIL: https://pypi.org/project/nanoarrow