From c05b9d154a53d2223124b4ef59c83b58d8571698 Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Sun, 14 Jul 2024 12:49:14 -0400 Subject: [PATCH 01/29] add legacy vlen-utf8 codec --- src/zarr/buffer.py | 5 ++- src/zarr/codecs/__init__.py | 2 + src/zarr/codecs/legacy_vlen.py | 68 +++++++++++++++++++++++++++++++ tests/v3/test_codecs/test_vlen.py | 27 ++++++++++++ 4 files changed, 101 insertions(+), 1 deletion(-) create mode 100644 src/zarr/codecs/legacy_vlen.py create mode 100644 tests/v3/test_codecs/test_vlen.py diff --git a/src/zarr/buffer.py b/src/zarr/buffer.py index 86f9b53477..bec992508e 100644 --- a/src/zarr/buffer.py +++ b/src/zarr/buffer.py @@ -283,7 +283,10 @@ class NDBuffer: def __init__(self, array: NDArrayLike): # assert array.ndim > 0 - assert array.dtype != object + + # Commented this out because string arrays have dtype object + # TODO: decide how to handle strings (e.g. numpy 2.0 StringDtype) + # assert array.dtype != object self._data = array @classmethod diff --git a/src/zarr/codecs/__init__.py b/src/zarr/codecs/__init__.py index 9394284319..c795e5e243 100644 --- a/src/zarr/codecs/__init__.py +++ b/src/zarr/codecs/__init__.py @@ -4,6 +4,7 @@ from zarr.codecs.bytes import BytesCodec, Endian from zarr.codecs.crc32c_ import Crc32cCodec from zarr.codecs.gzip import GzipCodec +from zarr.codecs.legacy_vlen import VLenUTF8Codec from zarr.codecs.pipeline import BatchedCodecPipeline from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation from zarr.codecs.transpose import TransposeCodec @@ -21,5 +22,6 @@ "ShardingCodec", "ShardingCodecIndexLocation", "TransposeCodec", + "VLenUTF8Codec", "ZstdCodec", ] diff --git a/src/zarr/codecs/legacy_vlen.py b/src/zarr/codecs/legacy_vlen.py new file mode 100644 index 0000000000..76f3d30390 --- /dev/null +++ b/src/zarr/codecs/legacy_vlen.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from numcodecs.vlen import VLenUTF8 + +from zarr.abc.codec import ArrayBytesCodec +from zarr.array_spec import ArraySpec +from zarr.buffer import Buffer, NDBuffer +from zarr.codecs.registry import register_codec +from zarr.common import JSON, parse_named_configuration + +if TYPE_CHECKING: + from typing_extensions import Self + + +# can use a global because there are no parameters +vlen_utf8_codec = VLenUTF8() + + +@dataclass(frozen=True) +class VLenUTF8Codec(ArrayBytesCodec): + def __init__(self) -> None: + pass + + @classmethod + def from_dict(cls, data: dict[str, JSON]) -> Self: + _, configuration_parsed = parse_named_configuration( + data, "vlen-utf8", require_configuration=False + ) + configuration_parsed = configuration_parsed or {} + return cls(**configuration_parsed) + + def to_dict(self) -> dict[str, JSON]: + return {"name": "vlen-utf8"} + + def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: + return self + + async def _decode_single( + self, + chunk_bytes: Buffer, + chunk_spec: ArraySpec, + ) -> NDBuffer: + assert isinstance(chunk_bytes, Buffer) + + raw_bytes = chunk_bytes.as_array_like() + decoded = vlen_utf8_codec.decode(raw_bytes) + decoded.shape = chunk_spec.shape + return chunk_spec.prototype.nd_buffer.from_numpy_array(decoded) + + async def _encode_single( + self, + chunk_array: NDBuffer, + chunk_spec: ArraySpec, + ) -> Buffer | None: + assert isinstance(chunk_array, NDBuffer) + return chunk_spec.prototype.buffer.from_bytes( + vlen_utf8_codec.encode(chunk_array.as_numpy_array()) + ) + + def compute_encoded_size(self, input_byte_length: 
int, _chunk_spec: ArraySpec) -> int: + # what is input_byte_length for an object dtype? + raise NotImplementedError("compute_encoded_size is not implemented for VLen codecs") + + +register_codec("vlen-utf8", VLenUTF8Codec) diff --git a/tests/v3/test_codecs/test_vlen.py b/tests/v3/test_codecs/test_vlen.py new file mode 100644 index 0000000000..1e2702708a --- /dev/null +++ b/tests/v3/test_codecs/test_vlen.py @@ -0,0 +1,27 @@ +import numpy as np +import pytest + +from zarr.abc.store import Store +from zarr.array import Array +from zarr.codecs import VLenUTF8Codec +from zarr.store.core import StorePath + + +@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) +def test_arrow_vlen_string(store: Store) -> None: + strings = ["hello", "world", "this", "is", "a", "test"] + data = np.array(strings).reshape((2, 3)) + + a = Array.create( + StorePath(store, path="arrow"), + shape=data.shape, + chunk_shape=data.shape, + dtype=data.dtype, + fill_value=0, + codecs=[VLenUTF8Codec()], + ) + + a[:, :] = data + print(a) + print(a[:]) + assert np.array_equal(data, a[:, :]) From a32212457e3c511456e3e64b1def07ad04708584 Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Sun, 29 Sep 2024 18:29:53 -0400 Subject: [PATCH 02/29] got it working again --- src/zarr/codecs/legacy_vlen.py | 11 ++++++----- src/zarr/core/metadata/v3.py | 3 +++ tests/v3/test_codecs/test_vlen.py | 14 ++++++-------- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/zarr/codecs/legacy_vlen.py b/src/zarr/codecs/legacy_vlen.py index 76f3d30390..10405c8e87 100644 --- a/src/zarr/codecs/legacy_vlen.py +++ b/src/zarr/codecs/legacy_vlen.py @@ -6,13 +6,14 @@ from numcodecs.vlen import VLenUTF8 from zarr.abc.codec import ArrayBytesCodec -from zarr.array_spec import ArraySpec -from zarr.buffer import Buffer, NDBuffer -from zarr.codecs.registry import register_codec -from zarr.common import JSON, parse_named_configuration +from zarr.core.buffer import Buffer, NDBuffer +from zarr.core.common import JSON, parse_named_configuration +from zarr.registry import register_codec if TYPE_CHECKING: - from typing_extensions import Self + from typing import Self + + from zarr.core.array_spec import ArraySpec # can use a global because there are no parameters diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index f0c6dc6282..c61c1058a0 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -445,6 +445,7 @@ class DataType(Enum): float64 = "float64" complex64 = "complex64" complex128 = "complex128" + string = "string" @property def byte_count(self) -> int: @@ -492,6 +493,8 @@ def to_numpy_shortname(self) -> str: @classmethod def from_dtype(cls, dtype: np.dtype[Any]) -> DataType: + if np.issubdtype(np.str_, dtype): + return DataType.string dtype_to_data_type = { "|b1": "bool", "bool": "bool", diff --git a/tests/v3/test_codecs/test_vlen.py b/tests/v3/test_codecs/test_vlen.py index 1e2702708a..2b82f3d500 100644 --- a/tests/v3/test_codecs/test_vlen.py +++ b/tests/v3/test_codecs/test_vlen.py @@ -1,27 +1,25 @@ import numpy as np import pytest +from zarr import Array from zarr.abc.store import Store -from zarr.array import Array from zarr.codecs import VLenUTF8Codec -from zarr.store.core import StorePath +from zarr.store.common import StorePath -@pytest.mark.parametrize("store", ("local", "memory"), indirect=["store"]) -def test_arrow_vlen_string(store: Store) -> None: +@pytest.mark.parametrize("store", ["local", "memory"], indirect=["store"]) +def test_vlen_string(store: Store) -> None: 
strings = ["hello", "world", "this", "is", "a", "test"] data = np.array(strings).reshape((2, 3)) a = Array.create( - StorePath(store, path="arrow"), + StorePath(store, path="string"), shape=data.shape, chunk_shape=data.shape, dtype=data.dtype, - fill_value=0, + fill_value="", codecs=[VLenUTF8Codec()], ) a[:, :] = data - print(a) - print(a[:]) assert np.array_equal(data, a[:, :]) From 2a1e2e32745ddfeb75553d8c8d7c028ce9285720 Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Mon, 30 Sep 2024 20:02:52 -0400 Subject: [PATCH 03/29] got strings working; broke everything else --- src/zarr/codecs/legacy_vlen.py | 5 +--- src/zarr/core/buffer/core.py | 2 +- src/zarr/core/config.py | 1 + src/zarr/core/metadata/v3.py | 47 +++++++++++++++++++++++-------- tests/v3/test_codecs/test_vlen.py | 19 +++++++++++-- 5 files changed, 55 insertions(+), 19 deletions(-) diff --git a/src/zarr/codecs/legacy_vlen.py b/src/zarr/codecs/legacy_vlen.py index 10405c8e87..19f2e293ca 100644 --- a/src/zarr/codecs/legacy_vlen.py +++ b/src/zarr/codecs/legacy_vlen.py @@ -22,9 +22,6 @@ @dataclass(frozen=True) class VLenUTF8Codec(ArrayBytesCodec): - def __init__(self) -> None: - pass - @classmethod def from_dict(cls, data: dict[str, JSON]) -> Self: _, configuration_parsed = parse_named_configuration( @@ -34,7 +31,7 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: return cls(**configuration_parsed) def to_dict(self) -> dict[str, JSON]: - return {"name": "vlen-utf8"} + return {"name": "vlen-utf8", "configuration": {}} def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: return self diff --git a/src/zarr/core/buffer/core.py b/src/zarr/core/buffer/core.py index 9a808e08b4..4cc3f217f1 100644 --- a/src/zarr/core/buffer/core.py +++ b/src/zarr/core/buffer/core.py @@ -472,7 +472,7 @@ def all_equal(self, other: Any, equal_nan: bool = True) -> bool: # use array_equal to obtain equal_nan=True functionality _data, other = np.broadcast_arrays(self._data, other) return np.array_equal( - self._data, other, equal_nan=equal_nan if self._data.dtype.kind not in "US" else False + self._data, other, equal_nan=equal_nan if self._data.dtype.kind not in "UST" else False ) def fill(self, value: Any) -> None: diff --git a/src/zarr/core/config.py b/src/zarr/core/config.py index 735755616f..2b8b27e6ef 100644 --- a/src/zarr/core/config.py +++ b/src/zarr/core/config.py @@ -58,6 +58,7 @@ def reset(self) -> None: "crc32c": "zarr.codecs.crc32c_.Crc32cCodec", "sharding_indexed": "zarr.codecs.sharding.ShardingCodec", "transpose": "zarr.codecs.transpose.TransposeCodec", + "vlen-utf8": "zarr.codecs.legacy_vlen.VLenUTF8Codec", }, "buffer": "zarr.core.buffer.cpu.Buffer", "ndbuffer": "zarr.core.buffer.cpu.NDBuffer", diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index c61c1058a0..d7e96f91b5 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -152,7 +152,7 @@ def _replace_special_floats(obj: object) -> Any: @dataclass(frozen=True, kw_only=True) class ArrayV3Metadata(ArrayMetadata): shape: ChunkCoords - data_type: np.dtype[Any] + data_type: DataType chunk_grid: ChunkGrid chunk_key_encoding: ChunkKeyEncoding fill_value: Any @@ -167,7 +167,7 @@ def __init__( self, *, shape: Iterable[int], - data_type: npt.DTypeLike, + data_type: str | np.dtype[Any] | DataType, chunk_grid: dict[str, JSON] | ChunkGrid, chunk_key_encoding: dict[str, JSON] | ChunkKeyEncoding, fill_value: Any, @@ -180,18 +180,18 @@ def __init__( Because the class is a frozen dataclass, we set attributes using object.__setattr__ """ 
shape_parsed = parse_shapelike(shape) - data_type_parsed = parse_dtype(data_type) + data_type_parsed = DataType.parse(data_type) chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid) chunk_key_encoding_parsed = ChunkKeyEncoding.from_dict(chunk_key_encoding) dimension_names_parsed = parse_dimension_names(dimension_names) - fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed) + fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed.to_numpy_dtype()) attributes_parsed = parse_attributes(attributes) codecs_parsed_partial = parse_codecs(codecs) storage_transformers_parsed = parse_storage_transformers(storage_transformers) array_spec = ArraySpec( shape=shape_parsed, - dtype=data_type_parsed, + dtype=data_type_parsed.to_numpy_dtype(), fill_value=fill_value_parsed, order="C", # TODO: order is not needed here. prototype=default_buffer_prototype(), # TODO: prototype is not needed here. @@ -224,11 +224,14 @@ def _validate_metadata(self) -> None: if self.fill_value is None: raise ValueError("`fill_value` is required.") for codec in self.codecs: - codec.validate(shape=self.shape, dtype=self.data_type, chunk_grid=self.chunk_grid) + codec.validate( + shape=self.shape, dtype=self.data_type.to_numpy_dtype(), chunk_grid=self.chunk_grid + ) @property def dtype(self) -> np.dtype[Any]: - return self.data_type + """Interpret Zarr dtype as NumPy dtype""" + return self.data_type.to_numpy_dtype() @property def ndim(self) -> int: @@ -266,7 +269,7 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: _ = parse_node_type_array(_data.pop("node_type")) # check that the data_type attribute is valid - _ = DataType(_data["data_type"]) + _data["data_type"] = DataType.parse(_data.pop("data_type")) # dimension_names key is optional, normalize missing to `None` _data["dimension_names"] = _data.pop("dimension_names", None) @@ -310,6 +313,7 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self: FLOAT = np.float16 | np.float32 | np.float64 COMPLEX_DTYPE = np.dtypes.Complex64DType | np.dtypes.Complex128DType COMPLEX = np.complex64 | np.complex128 +STRING = np.str_ @overload @@ -491,8 +495,14 @@ def to_numpy_shortname(self) -> str: } return data_type_to_numpy[self] + def to_numpy_dtype(self) -> np.dtype[Any]: + if self == DataType.string: + return np.dtypes.StringDType() + else: + return np.dtype(self.to_numpy_shortname()) + @classmethod - def from_dtype(cls, dtype: np.dtype[Any]) -> DataType: + def from_numpy_dtype(cls, dtype: np.dtype[Any]) -> DataType: if np.issubdtype(np.str_, dtype): return DataType.string dtype_to_data_type = { @@ -514,15 +524,30 @@ def from_dtype(cls, dtype: np.dtype[Any]) -> DataType: } return DataType[dtype_to_data_type[dtype.str]] + @classmethod + def parse(cls, dtype: str | np.dtype[Any] | DataType) -> DataType: + if isinstance(dtype, DataType): + return dtype + elif isinstance(dtype, np.dtype): + return cls.from_numpy_dtype(dtype) + elif isinstance(dtype, str): + try: + return cls(dtype) + except ValueError as e: + raise TypeError(f"Invalid V3 data_type: {dtype}") from e + else: + raise TypeError(f"Invalid V3 data_type: {dtype}") + -def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]: +def numpy_dtype_to_zarr_data_type(data: npt.DTypeLike) -> DataType: try: dtype = np.dtype(data) except (ValueError, TypeError) as e: raise ValueError(f"Invalid V3 data_type: {data}") from e # check that this is a valid v3 data_type try: - _ = DataType.from_dtype(dtype) + # dtype = DataType.from_dtype(dtype) + _ = DataType.from_numpy_dtype(dtype) except KeyError as e: 
raise ValueError(f"Invalid V3 data_type: {dtype}") from e diff --git a/tests/v3/test_codecs/test_vlen.py b/tests/v3/test_codecs/test_vlen.py index 2b82f3d500..9fdcfab193 100644 --- a/tests/v3/test_codecs/test_vlen.py +++ b/tests/v3/test_codecs/test_vlen.py @@ -4,16 +4,21 @@ from zarr import Array from zarr.abc.store import Store from zarr.codecs import VLenUTF8Codec +from zarr.core.metadata.v3 import DataType from zarr.store.common import StorePath -@pytest.mark.parametrize("store", ["local", "memory"], indirect=["store"]) -def test_vlen_string(store: Store) -> None: +@pytest.mark.parametrize("store", ["memory", "local"], indirect=["store"]) +@pytest.mark.parametrize("dtype", [None, np.dtypes.StrDType]) +async def test_vlen_string(store: Store, dtype) -> None: strings = ["hello", "world", "this", "is", "a", "test"] data = np.array(strings).reshape((2, 3)) + if dtype is not None: + data = data.astype(dtype) + sp = StorePath(store, path="string") a = Array.create( - StorePath(store, path="string"), + sp, shape=data.shape, chunk_shape=data.shape, dtype=data.dtype, @@ -23,3 +28,11 @@ def test_vlen_string(store: Store) -> None: a[:, :] = data assert np.array_equal(data, a[:, :]) + assert a.metadata.data_type == DataType.string + assert a.dtype == np.dtypes.StringDType() + + # test round trip + b = Array.open(sp) + assert np.array_equal(data, b[:, :]) + assert b.metadata.data_type == DataType.string + assert b.dtype == np.dtypes.StringDType() From 1d3d7a5590f0ecabb47993d49a391d50639aa40e Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Mon, 30 Sep 2024 21:35:49 -0400 Subject: [PATCH 04/29] change v3.metadata.data_type type --- src/zarr/core/metadata/v3.py | 60 ++++++++++++++++++------------- tests/v3/test_metadata/test_v3.py | 11 +++--- 2 files changed, 40 insertions(+), 31 deletions(-) diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index f0c6dc6282..91126f310c 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -6,8 +6,6 @@ if TYPE_CHECKING: from typing import Self - import numpy.typing as npt - from zarr.core.buffer import Buffer, BufferPrototype from zarr.core.chunk_grids import ChunkGrid from zarr.core.common import JSON, ChunkCoords @@ -20,6 +18,7 @@ import numcodecs.abc import numpy as np +import numpy.typing as npt from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec from zarr.core.array_spec import ArraySpec @@ -152,7 +151,7 @@ def _replace_special_floats(obj: object) -> Any: @dataclass(frozen=True, kw_only=True) class ArrayV3Metadata(ArrayMetadata): shape: ChunkCoords - data_type: np.dtype[Any] + data_type: DataType chunk_grid: ChunkGrid chunk_key_encoding: ChunkKeyEncoding fill_value: Any @@ -167,7 +166,7 @@ def __init__( self, *, shape: Iterable[int], - data_type: npt.DTypeLike, + data_type: npt.DTypeLike | DataType, chunk_grid: dict[str, JSON] | ChunkGrid, chunk_key_encoding: dict[str, JSON] | ChunkKeyEncoding, fill_value: Any, @@ -180,18 +179,18 @@ def __init__( Because the class is a frozen dataclass, we set attributes using object.__setattr__ """ shape_parsed = parse_shapelike(shape) - data_type_parsed = parse_dtype(data_type) + data_type_parsed = DataType.parse(data_type) chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid) chunk_key_encoding_parsed = ChunkKeyEncoding.from_dict(chunk_key_encoding) dimension_names_parsed = parse_dimension_names(dimension_names) - fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed) + fill_value_parsed = parse_fill_value(fill_value, 
dtype=data_type_parsed.to_numpy_dtype()) attributes_parsed = parse_attributes(attributes) codecs_parsed_partial = parse_codecs(codecs) storage_transformers_parsed = parse_storage_transformers(storage_transformers) array_spec = ArraySpec( shape=shape_parsed, - dtype=data_type_parsed, + dtype=data_type_parsed.to_numpy_dtype(), fill_value=fill_value_parsed, order="C", # TODO: order is not needed here. prototype=default_buffer_prototype(), # TODO: prototype is not needed here. @@ -224,11 +223,14 @@ def _validate_metadata(self) -> None: if self.fill_value is None: raise ValueError("`fill_value` is required.") for codec in self.codecs: - codec.validate(shape=self.shape, dtype=self.data_type, chunk_grid=self.chunk_grid) + codec.validate( + shape=self.shape, dtype=self.data_type.to_numpy_dtype(), chunk_grid=self.chunk_grid + ) @property def dtype(self) -> np.dtype[Any]: - return self.data_type + """Interpret Zarr dtype as NumPy dtype""" + return self.data_type.to_numpy_dtype() @property def ndim(self) -> int: @@ -266,13 +268,13 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: _ = parse_node_type_array(_data.pop("node_type")) # check that the data_type attribute is valid - _ = DataType(_data["data_type"]) + data_type = DataType.parse(_data.pop("data_type")) # dimension_names key is optional, normalize missing to `None` _data["dimension_names"] = _data.pop("dimension_names", None) # attributes key is optional, normalize missing to `None` _data["attributes"] = _data.pop("attributes", None) - return cls(**_data) # type: ignore[arg-type] + return cls(**_data, data_type=data_type) # type: ignore[arg-type] def to_dict(self) -> dict[str, JSON]: out_dict = super().to_dict() @@ -490,8 +492,11 @@ def to_numpy_shortname(self) -> str: } return data_type_to_numpy[self] + def to_numpy_dtype(self) -> np.dtype[Any]: + return np.dtype(self.to_numpy_shortname()) + @classmethod - def from_dtype(cls, dtype: np.dtype[Any]) -> DataType: + def from_numpy_dtype(cls, dtype: np.dtype[Any]) -> DataType: dtype_to_data_type = { "|b1": "bool", "bool": "bool", @@ -511,16 +516,21 @@ def from_dtype(cls, dtype: np.dtype[Any]) -> DataType: } return DataType[dtype_to_data_type[dtype.str]] - -def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]: - try: - dtype = np.dtype(data) - except (ValueError, TypeError) as e: - raise ValueError(f"Invalid V3 data_type: {data}") from e - # check that this is a valid v3 data_type - try: - _ = DataType.from_dtype(dtype) - except KeyError as e: - raise ValueError(f"Invalid V3 data_type: {dtype}") from e - - return dtype + @classmethod + def parse(cls, dtype: None | DataType | Any) -> DataType: + if dtype is None: + # the default dtype + return DataType.float64 + if isinstance(dtype, DataType): + return dtype + else: + try: + dtype = np.dtype(dtype) + except (ValueError, TypeError) as e: + raise ValueError(f"Invalid V3 data_type: {dtype}") from e + # check that this is a valid v3 data_type + try: + data_type = DataType.from_numpy_dtype(dtype) + except KeyError as e: + raise ValueError(f"Invalid V3 data_type: {dtype}") from e + return data_type diff --git a/tests/v3/test_metadata/test_v3.py b/tests/v3/test_metadata/test_v3.py index 71dc917c35..534ef61d09 100644 --- a/tests/v3/test_metadata/test_v3.py +++ b/tests/v3/test_metadata/test_v3.py @@ -7,7 +7,7 @@ from zarr.codecs.bytes import BytesCodec from zarr.core.buffer import default_buffer_prototype from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding -from zarr.core.metadata.v3 import ArrayV3Metadata +from 
zarr.core.metadata.v3 import ArrayV3Metadata, DataType if TYPE_CHECKING: from collections.abc import Sequence @@ -22,7 +22,6 @@ from zarr.core.metadata.v3 import ( parse_dimension_names, - parse_dtype, parse_fill_value, parse_zarr_format, ) @@ -209,7 +208,7 @@ def test_metadata_to_dict( storage_transformers: None | tuple[dict[str, JSON]], ) -> None: shape = (1, 2, 3) - data_type = "uint8" + data_type = DataType.uint8 if chunk_grid == "regular": cgrid = {"name": "regular", "configuration": {"chunk_shape": (1, 1, 1)}} @@ -290,7 +289,7 @@ def test_metadata_to_dict( # assert result["fill_value"] == fill_value -async def test_invalid_dtype_raises() -> None: +def test_invalid_dtype_raises() -> None: metadata_dict = { "zarr_format": 3, "node_type": "array", @@ -301,14 +300,14 @@ async def test_invalid_dtype_raises() -> None: "codecs": (), "fill_value": np.datetime64(0, "ns"), } - with pytest.raises(ValueError, match=r".* is not a valid DataType"): + with pytest.raises(ValueError, match=r"Invalid V3 data_type: .*"): ArrayV3Metadata.from_dict(metadata_dict) @pytest.mark.parametrize("data", ["datetime64[s]", "foo", object()]) def test_parse_invalid_dtype_raises(data): with pytest.raises(ValueError, match=r"Invalid V3 data_type: .*"): - parse_dtype(data) + DataType.parse(data) @pytest.mark.parametrize( From 988f9df66d66719332f16c52c23ab0b62722403c Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Mon, 30 Sep 2024 22:09:55 -0400 Subject: [PATCH 05/29] fixed tests --- src/zarr/core/metadata/v3.py | 25 ++++++++++++++----------- tests/v3/test_config.py | 1 + 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 818a27a7bf..52bfdb2419 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -530,14 +530,17 @@ def parse(cls, dtype: None | DataType | Any) -> DataType: return DataType.float64 if isinstance(dtype, DataType): return dtype - else: - try: - dtype = np.dtype(dtype) - except (ValueError, TypeError) as e: - raise ValueError(f"Invalid V3 data_type: {dtype}") from e - # check that this is a valid v3 data_type - try: - data_type = DataType.from_numpy_dtype(dtype) - except KeyError as e: - raise ValueError(f"Invalid V3 data_type: {dtype}") from e - return data_type + try: + return DataType(dtype) + except ValueError: + pass + try: + dtype = np.dtype(dtype) + except (ValueError, TypeError) as e: + raise ValueError(f"Invalid V3 data_type: {dtype}") from e + # check that this is a valid v3 data_type + try: + data_type = DataType.from_numpy_dtype(dtype) + except KeyError as e: + raise ValueError(f"Invalid V3 data_type: {dtype}") from e + return data_type diff --git a/tests/v3/test_config.py b/tests/v3/test_config.py index e324367b3d..25b5b4fcd1 100644 --- a/tests/v3/test_config.py +++ b/tests/v3/test_config.py @@ -58,6 +58,7 @@ def test_config_defaults_set() -> None: "crc32c": "zarr.codecs.crc32c_.Crc32cCodec", "sharding_indexed": "zarr.codecs.sharding.ShardingCodec", "transpose": "zarr.codecs.transpose.TransposeCodec", + "vlen-utf8": "zarr.codecs.legacy_vlen.VLenUTF8Codec", }, } ] From 507161a99721fe3b6ec2e631482ff797899ee225 Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Mon, 30 Sep 2024 22:18:10 -0400 Subject: [PATCH 06/29] satisfy mypy for tests --- tests/v3/test_codecs/test_vlen.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/v3/test_codecs/test_vlen.py b/tests/v3/test_codecs/test_vlen.py index 9fdcfab193..2a5aaa606c 100644 --- 
a/tests/v3/test_codecs/test_vlen.py +++ b/tests/v3/test_codecs/test_vlen.py @@ -1,16 +1,18 @@ +from typing import Any + import numpy as np import pytest from zarr import Array from zarr.abc.store import Store from zarr.codecs import VLenUTF8Codec -from zarr.core.metadata.v3 import DataType +from zarr.core.metadata.v3 import ArrayV3Metadata, DataType from zarr.store.common import StorePath @pytest.mark.parametrize("store", ["memory", "local"], indirect=["store"]) @pytest.mark.parametrize("dtype", [None, np.dtypes.StrDType]) -async def test_vlen_string(store: Store, dtype) -> None: +async def test_vlen_string(store: Store, dtype: None | np.dtype[Any]) -> None: strings = ["hello", "world", "this", "is", "a", "test"] data = np.array(strings).reshape((2, 3)) if dtype is not None: @@ -25,6 +27,7 @@ async def test_vlen_string(store: Store, dtype) -> None: fill_value="", codecs=[VLenUTF8Codec()], ) + assert isinstance(a.metadata, ArrayV3Metadata) # needed for mypy a[:, :] = data assert np.array_equal(data, a[:, :]) @@ -33,6 +36,7 @@ async def test_vlen_string(store: Store, dtype) -> None: # test round trip b = Array.open(sp) + assert isinstance(b.metadata, ArrayV3Metadata) # needed for mypy assert np.array_equal(data, b[:, :]) assert b.metadata.data_type == DataType.string assert b.dtype == np.dtypes.StringDType() From 1ae5e6315b1dcf3b89d7c8e14b78f6bb5577071e Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Thu, 3 Oct 2024 16:26:01 -0400 Subject: [PATCH 07/29] make strings work --- src/zarr/codecs/legacy_vlen.py | 7 ++++++- src/zarr/core/buffer/core.py | 9 ++++----- src/zarr/core/metadata/v3.py | 4 ++-- tests/v3/test_codecs/test_vlen.py | 15 ++++++++++++--- 4 files changed, 24 insertions(+), 11 deletions(-) diff --git a/src/zarr/codecs/legacy_vlen.py b/src/zarr/codecs/legacy_vlen.py index 19f2e293ca..c793328f6f 100644 --- a/src/zarr/codecs/legacy_vlen.py +++ b/src/zarr/codecs/legacy_vlen.py @@ -3,12 +3,14 @@ from dataclasses import dataclass from typing import TYPE_CHECKING +import numpy as np from numcodecs.vlen import VLenUTF8 from zarr.abc.codec import ArrayBytesCodec from zarr.core.buffer import Buffer, NDBuffer from zarr.core.common import JSON, parse_named_configuration from zarr.registry import register_codec +from zarr.strings import cast_to_string_dtype if TYPE_CHECKING: from typing import Self @@ -45,8 +47,11 @@ async def _decode_single( raw_bytes = chunk_bytes.as_array_like() decoded = vlen_utf8_codec.decode(raw_bytes) + assert decoded.dtype == np.object_ decoded.shape = chunk_spec.shape - return chunk_spec.prototype.nd_buffer.from_numpy_array(decoded) + # coming out of the code, we know this is safe, so don't issue a warning + as_string_dtype = cast_to_string_dtype(decoded, safe=True) + return chunk_spec.prototype.nd_buffer.from_numpy_array(as_string_dtype) async def _encode_single( self, diff --git a/src/zarr/core/buffer/core.py b/src/zarr/core/buffer/core.py index 4cc3f217f1..b520e21ee3 100644 --- a/src/zarr/core/buffer/core.py +++ b/src/zarr/core/buffer/core.py @@ -313,10 +313,6 @@ class NDBuffer: """ def __init__(self, array: NDArrayLike) -> None: - # assert array.ndim > 0 - - # Commented this out because string arrays have dtype object - # TODO: decide how to handle strings (e.g. 
numpy 2.0 StringDtype) # assert array.dtype != object self._data = array @@ -470,9 +466,12 @@ def all_equal(self, other: Any, equal_nan: bool = True) -> bool: # Handle None fill_value for Zarr V2 return False # use array_equal to obtain equal_nan=True functionality + # Note from Ryan: doesn't this lead to a huge amount of unnecessary memory allocation on every single chunk? + # Since fill-value is a scalar, isn't there a faster path than allocating a new array for fill value + # every single time we have to write data? _data, other = np.broadcast_arrays(self._data, other) return np.array_equal( - self._data, other, equal_nan=equal_nan if self._data.dtype.kind not in "UST" else False + self._data, other, equal_nan=equal_nan if self._data.dtype.kind not in "USTO" else False ) def fill(self, value: Any) -> None: diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 52bfdb2419..d83d354cea 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -29,6 +29,7 @@ from zarr.core.config import config from zarr.core.metadata.common import ArrayMetadata, parse_attributes from zarr.registry import get_codec_class +from zarr.strings import STRING_DTYPE def parse_zarr_format(data: object) -> Literal[3]: @@ -312,7 +313,6 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self: FLOAT = np.float16 | np.float32 | np.float64 COMPLEX_DTYPE = np.dtypes.Complex64DType | np.dtypes.Complex128DType COMPLEX = np.complex64 | np.complex128 -STRING = np.str_ @overload @@ -496,7 +496,7 @@ def to_numpy_shortname(self) -> str: def to_numpy_dtype(self) -> np.dtype[Any]: if self == DataType.string: - return np.dtypes.StringDType() + return STRING_DTYPE else: return np.dtype(self.to_numpy_shortname()) diff --git a/tests/v3/test_codecs/test_vlen.py b/tests/v3/test_codecs/test_vlen.py index 2a5aaa606c..03df907a2f 100644 --- a/tests/v3/test_codecs/test_vlen.py +++ b/tests/v3/test_codecs/test_vlen.py @@ -8,10 +8,19 @@ from zarr.codecs import VLenUTF8Codec from zarr.core.metadata.v3 import ArrayV3Metadata, DataType from zarr.store.common import StorePath +from zarr.strings import NUMPY_SUPPORTS_VLEN_STRING + +numpy_str_dtypes: list[type | None] = [None, str, np.dtypes.StrDType] +expected_zarr_string_dtype: np.dtype[Any] +if NUMPY_SUPPORTS_VLEN_STRING: + numpy_str_dtypes.append(np.dtypes.StringDType) + expected_zarr_string_dtype = np.dtypes.StringDType() +else: + expected_zarr_string_dtype = np.dtype("O") @pytest.mark.parametrize("store", ["memory", "local"], indirect=["store"]) -@pytest.mark.parametrize("dtype", [None, np.dtypes.StrDType]) +@pytest.mark.parametrize("dtype", numpy_str_dtypes) async def test_vlen_string(store: Store, dtype: None | np.dtype[Any]) -> None: strings = ["hello", "world", "this", "is", "a", "test"] data = np.array(strings).reshape((2, 3)) @@ -32,11 +41,11 @@ async def test_vlen_string(store: Store, dtype: None | np.dtype[Any]) -> None: a[:, :] = data assert np.array_equal(data, a[:, :]) assert a.metadata.data_type == DataType.string - assert a.dtype == np.dtypes.StringDType() + assert a.dtype == expected_zarr_string_dtype # test round trip b = Array.open(sp) assert isinstance(b.metadata, ArrayV3Metadata) # needed for mypy assert np.array_equal(data, b[:, :]) assert b.metadata.data_type == DataType.string - assert b.dtype == np.dtypes.StringDType() + assert a.dtype == expected_zarr_string_dtype From 94ecdb5889942d28966afbb7129a7d8ad5e1cd57 Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Thu, 3 Oct 2024 16:30:27 -0400 Subject: [PATCH 08/29] 
add missing module --- src/zarr/strings.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 src/zarr/strings.py diff --git a/src/zarr/strings.py b/src/zarr/strings.py new file mode 100644 index 0000000000..c3dd1a3395 --- /dev/null +++ b/src/zarr/strings.py @@ -0,0 +1,36 @@ +from typing import Any +from warnings import warn + +import numpy as np + +try: + STRING_DTYPE = np.dtype("T") + NUMPY_SUPPORTS_VLEN_STRING = True +except TypeError: + STRING_DTYPE = np.dtype("object") + NUMPY_SUPPORTS_VLEN_STRING = False + + +def cast_to_string_dtype( + data: np.ndarray[Any, np.dtype[Any]], safe: bool = False +) -> np.ndarray[Any, np.dtype[Any]]: + if np.issubdtype(data.dtype, np.str_): + return data + if np.issubdtype(data.dtype, np.object_): + if NUMPY_SUPPORTS_VLEN_STRING: + try: + # cast to variable-length string dtype, fail if object contains non-string data + # mypy says "error: Unexpected keyword argument "coerce" for "StringDType" [call-arg]" + return data.astype(np.dtypes.StringDType(coerce=False), copy=False) # type: ignore[call-arg] + except ValueError as e: + raise ValueError("Cannot cast object dtype to string dtype") from e + else: + out = data.astype(np.str_) + if not safe: + warn( + f"Casted object dtype to string dtype {out.dtype}. To avoid this warning, " + "cast the data to a string dtype before passing to Zarr or upgrade to NumPy >= 2.", + stacklevel=2, + ) + return out + raise ValueError(f"Cannot cast dtype {data.dtype} to string dtype") From 79b7d4337c9414ee66aebc263071617fb48e4612 Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Fri, 4 Oct 2024 08:42:49 -0400 Subject: [PATCH 09/29] store -> storage --- tests/v3/test_codecs/test_vlen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v3/test_codecs/test_vlen.py b/tests/v3/test_codecs/test_vlen.py index 03df907a2f..7e30d1140f 100644 --- a/tests/v3/test_codecs/test_vlen.py +++ b/tests/v3/test_codecs/test_vlen.py @@ -7,7 +7,7 @@ from zarr.abc.store import Store from zarr.codecs import VLenUTF8Codec from zarr.core.metadata.v3 import ArrayV3Metadata, DataType -from zarr.store.common import StorePath +from zarr.storage.common import StorePath from zarr.strings import NUMPY_SUPPORTS_VLEN_STRING numpy_str_dtypes: list[type | None] = [None, str, np.dtypes.StrDType] From a5c2a37470d41dbe49dc091e189dc546332148c6 Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Fri, 4 Oct 2024 08:46:19 -0400 Subject: [PATCH 10/29] rename module --- src/zarr/codecs/__init__.py | 2 +- src/zarr/codecs/{legacy_vlen.py => vlen_utf8.py} | 0 src/zarr/core/config.py | 2 +- tests/v3/test_config.py | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) rename src/zarr/codecs/{legacy_vlen.py => vlen_utf8.py} (100%) diff --git a/src/zarr/codecs/__init__.py b/src/zarr/codecs/__init__.py index c795e5e243..1d90b2651c 100644 --- a/src/zarr/codecs/__init__.py +++ b/src/zarr/codecs/__init__.py @@ -4,10 +4,10 @@ from zarr.codecs.bytes import BytesCodec, Endian from zarr.codecs.crc32c_ import Crc32cCodec from zarr.codecs.gzip import GzipCodec -from zarr.codecs.legacy_vlen import VLenUTF8Codec from zarr.codecs.pipeline import BatchedCodecPipeline from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation from zarr.codecs.transpose import TransposeCodec +from zarr.codecs.vlen_utf8 import VLenUTF8Codec from zarr.codecs.zstd import ZstdCodec __all__ = [ diff --git a/src/zarr/codecs/legacy_vlen.py b/src/zarr/codecs/vlen_utf8.py similarity index 100% rename from src/zarr/codecs/legacy_vlen.py 
rename to src/zarr/codecs/vlen_utf8.py diff --git a/src/zarr/core/config.py b/src/zarr/core/config.py index 2b8b27e6ef..d79db0d47c 100644 --- a/src/zarr/core/config.py +++ b/src/zarr/core/config.py @@ -58,7 +58,7 @@ def reset(self) -> None: "crc32c": "zarr.codecs.crc32c_.Crc32cCodec", "sharding_indexed": "zarr.codecs.sharding.ShardingCodec", "transpose": "zarr.codecs.transpose.TransposeCodec", - "vlen-utf8": "zarr.codecs.legacy_vlen.VLenUTF8Codec", + "vlen-utf8": "zarr.codecs.vlen_utf8.VLenUTF8Codec", }, "buffer": "zarr.core.buffer.cpu.Buffer", "ndbuffer": "zarr.core.buffer.cpu.NDBuffer", diff --git a/tests/v3/test_config.py b/tests/v3/test_config.py index 25b5b4fcd1..f59241ebfb 100644 --- a/tests/v3/test_config.py +++ b/tests/v3/test_config.py @@ -58,7 +58,7 @@ def test_config_defaults_set() -> None: "crc32c": "zarr.codecs.crc32c_.Crc32cCodec", "sharding_indexed": "zarr.codecs.sharding.ShardingCodec", "transpose": "zarr.codecs.transpose.TransposeCodec", - "vlen-utf8": "zarr.codecs.legacy_vlen.VLenUTF8Codec", + "vlen-utf8": "zarr.codecs.vlen_utf8.VLenUTF8Codec", }, } ] From 0406ea1b1e12b29d93b983dff4374520721d2a89 Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Mon, 7 Oct 2024 09:50:49 -0400 Subject: [PATCH 11/29] add vlen bytes --- src/zarr/codecs/__init__.py | 3 +- src/zarr/codecs/vlen_utf8.py | 48 ++++++++++++++++++++++++++++++- src/zarr/core/config.py | 1 + src/zarr/core/metadata/v3.py | 11 ++++++- tests/v3/test_codecs/test_vlen.py | 32 ++++++++++++++++++++- tests/v3/test_config.py | 1 + 6 files changed, 92 insertions(+), 4 deletions(-) diff --git a/src/zarr/codecs/__init__.py b/src/zarr/codecs/__init__.py index 1d90b2651c..693c47d622 100644 --- a/src/zarr/codecs/__init__.py +++ b/src/zarr/codecs/__init__.py @@ -7,7 +7,7 @@ from zarr.codecs.pipeline import BatchedCodecPipeline from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation from zarr.codecs.transpose import TransposeCodec -from zarr.codecs.vlen_utf8 import VLenUTF8Codec +from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec from zarr.codecs.zstd import ZstdCodec __all__ = [ @@ -23,5 +23,6 @@ "ShardingCodecIndexLocation", "TransposeCodec", "VLenUTF8Codec", + "VLenBytesCodec", "ZstdCodec", ] diff --git a/src/zarr/codecs/vlen_utf8.py b/src/zarr/codecs/vlen_utf8.py index c793328f6f..dd21f18999 100644 --- a/src/zarr/codecs/vlen_utf8.py +++ b/src/zarr/codecs/vlen_utf8.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING import numpy as np -from numcodecs.vlen import VLenUTF8 +from numcodecs.vlen import VLenBytes, VLenUTF8 from zarr.abc.codec import ArrayBytesCodec from zarr.core.buffer import Buffer, NDBuffer @@ -20,6 +20,7 @@ # can use a global because there are no parameters vlen_utf8_codec = VLenUTF8() +vlen_bytes_codec = VLenBytes() @dataclass(frozen=True) @@ -68,4 +69,49 @@ def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) - raise NotImplementedError("compute_encoded_size is not implemented for VLen codecs") +@dataclass(frozen=True) +class VLenBytesCodec(ArrayBytesCodec): + @classmethod + def from_dict(cls, data: dict[str, JSON]) -> Self: + _, configuration_parsed = parse_named_configuration( + data, "vlen-bytes", require_configuration=False + ) + configuration_parsed = configuration_parsed or {} + return cls(**configuration_parsed) + + def to_dict(self) -> dict[str, JSON]: + return {"name": "vlen-bytes", "configuration": {}} + + def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self: + return self + + async def _decode_single( + self, + chunk_bytes: Buffer, 
+ chunk_spec: ArraySpec, + ) -> NDBuffer: + assert isinstance(chunk_bytes, Buffer) + + raw_bytes = chunk_bytes.as_array_like() + decoded = vlen_bytes_codec.decode(raw_bytes) + assert decoded.dtype == np.object_ + decoded.shape = chunk_spec.shape + return chunk_spec.prototype.nd_buffer.from_numpy_array(decoded) + + async def _encode_single( + self, + chunk_array: NDBuffer, + chunk_spec: ArraySpec, + ) -> Buffer | None: + assert isinstance(chunk_array, NDBuffer) + return chunk_spec.prototype.buffer.from_bytes( + vlen_bytes_codec.encode(chunk_array.as_numpy_array()) + ) + + def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: + # what is input_byte_length for an object dtype? + raise NotImplementedError("compute_encoded_size is not implemented for VLen codecs") + + register_codec("vlen-utf8", VLenUTF8Codec) +register_codec("vlen-bytes", VLenBytesCodec) diff --git a/src/zarr/core/config.py b/src/zarr/core/config.py index d79db0d47c..3fe7d803d2 100644 --- a/src/zarr/core/config.py +++ b/src/zarr/core/config.py @@ -59,6 +59,7 @@ def reset(self) -> None: "sharding_indexed": "zarr.codecs.sharding.ShardingCodec", "transpose": "zarr.codecs.transpose.TransposeCodec", "vlen-utf8": "zarr.codecs.vlen_utf8.VLenUTF8Codec", + "vlen-bytes": "zarr.codecs.vlen_utf8.VLenBytesCodec", }, "buffer": "zarr.core.buffer.cpu.Buffer", "ndbuffer": "zarr.core.buffer.cpu.NDBuffer", diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 9375434e07..357c569b51 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -386,6 +386,8 @@ def parse_fill_value( """ if fill_value is None: return dtype.type(0) + if dtype.kind == "O": + return fill_value if isinstance(fill_value, Sequence) and not isinstance(fill_value, str): if dtype.type in (np.complex64, np.complex128): dtype = cast(COMPLEX_DTYPE, dtype) @@ -451,6 +453,7 @@ class DataType(Enum): complex64 = "complex64" complex128 = "complex128" string = "string" + bytes = "bytes" @property def byte_count(self) -> int: @@ -499,13 +502,19 @@ def to_numpy_shortname(self) -> str: def to_numpy(self) -> np.dtype[Any]: if self == DataType.string: return STRING_DTYPE + elif self == DataType.bytes: + # for now always use object dtype for bytestrings + # TODO: consider whether we can use fixed-width types (e.g. 
'|S5') instead + return np.dtype("O") else: return np.dtype(self.to_numpy_shortname()) @classmethod def from_numpy(cls, dtype: np.dtype[Any]) -> DataType: - if np.issubdtype(np.str_, dtype): + if dtype.kind in "UT": return DataType.string + elif dtype.kind == "S": + return DataType.bytes dtype_to_data_type = { "|b1": "bool", "bool": "bool", diff --git a/tests/v3/test_codecs/test_vlen.py b/tests/v3/test_codecs/test_vlen.py index 7e30d1140f..e87fe84c9e 100644 --- a/tests/v3/test_codecs/test_vlen.py +++ b/tests/v3/test_codecs/test_vlen.py @@ -5,7 +5,7 @@ from zarr import Array from zarr.abc.store import Store -from zarr.codecs import VLenUTF8Codec +from zarr.codecs import VLenBytesCodec, VLenUTF8Codec from zarr.core.metadata.v3 import ArrayV3Metadata, DataType from zarr.storage.common import StorePath from zarr.strings import NUMPY_SUPPORTS_VLEN_STRING @@ -49,3 +49,33 @@ async def test_vlen_string(store: Store, dtype: None | np.dtype[Any]) -> None: assert np.array_equal(data, b[:, :]) assert b.metadata.data_type == DataType.string assert a.dtype == expected_zarr_string_dtype + + +@pytest.mark.parametrize("store", ["memory", "local"], indirect=["store"]) +async def test_vlen_bytes(store: Store) -> None: + bstrings = [b"hello", b"world", b"this", b"is", b"a", b"test"] + data = np.array(bstrings).reshape((2, 3)) + assert data.dtype == "|S5" + + sp = StorePath(store, path="string") + a = Array.create( + sp, + shape=data.shape, + chunk_shape=data.shape, + dtype=data.dtype, + fill_value=b"", + codecs=[VLenBytesCodec()], + ) + assert isinstance(a.metadata, ArrayV3Metadata) # needed for mypy + + a[:, :] = data + assert np.array_equal(data, a[:, :]) + # assert a.metadata.data_type == DataType.string + # assert a.dtype == expected_zarr_string_dtype + + # test round trip + b = Array.open(sp) + assert isinstance(b.metadata, ArrayV3Metadata) # needed for mypy + assert np.array_equal(data, b[:, :]) + # assert b.metadata.data_type == DataType.string + # assert a.dtype == expected_zarr_string_dtype diff --git a/tests/v3/test_config.py b/tests/v3/test_config.py index f59241ebfb..2adc51aa57 100644 --- a/tests/v3/test_config.py +++ b/tests/v3/test_config.py @@ -59,6 +59,7 @@ def test_config_defaults_set() -> None: "sharding_indexed": "zarr.codecs.sharding.ShardingCodec", "transpose": "zarr.codecs.transpose.TransposeCodec", "vlen-utf8": "zarr.codecs.vlen_utf8.VLenUTF8Codec", + "vlen-bytes": "zarr.codecs.vlen_utf8.VLenBytesCodec", }, } ] From 8e61a18cd67ba5a013c2cb613b0708a9aad196a3 Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Mon, 7 Oct 2024 10:11:31 -0400 Subject: [PATCH 12/29] fix type assertions in test --- tests/v3/test_codecs/test_vlen.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/v3/test_codecs/test_vlen.py b/tests/v3/test_codecs/test_vlen.py index e87fe84c9e..e6de2ed767 100644 --- a/tests/v3/test_codecs/test_vlen.py +++ b/tests/v3/test_codecs/test_vlen.py @@ -70,12 +70,12 @@ async def test_vlen_bytes(store: Store) -> None: a[:, :] = data assert np.array_equal(data, a[:, :]) - # assert a.metadata.data_type == DataType.string - # assert a.dtype == expected_zarr_string_dtype + assert a.metadata.data_type == DataType.bytes + assert a.dtype == "O" # test round trip b = Array.open(sp) assert isinstance(b.metadata, ArrayV3Metadata) # needed for mypy assert np.array_equal(data, b[:, :]) - # assert b.metadata.data_type == DataType.string - # assert a.dtype == expected_zarr_string_dtype + assert b.metadata.data_type == DataType.bytes + assert a.dtype == "O" From 
6cf7dde6214450970661a267f7409217f62e4830 Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Mon, 7 Oct 2024 14:54:47 -0400 Subject: [PATCH 13/29] much better validation of fill value --- src/zarr/codecs/__init__.py | 17 +++++++++++ src/zarr/core/array.py | 4 +-- src/zarr/core/metadata/v3.py | 44 ++++++++++++++++++++++++++-- tests/v3/test_codecs/test_vlen.py | 48 +++++++++++++++++++++++++------ tests/v3/test_metadata/test_v3.py | 19 ++++++++---- 5 files changed, 114 insertions(+), 18 deletions(-) diff --git a/src/zarr/codecs/__init__.py b/src/zarr/codecs/__init__.py index 693c47d622..15c0e4c8f6 100644 --- a/src/zarr/codecs/__init__.py +++ b/src/zarr/codecs/__init__.py @@ -1,5 +1,9 @@ from __future__ import annotations +from typing import Any + +import numpy as np + from zarr.codecs.blosc import BloscCname, BloscCodec, BloscShuffle from zarr.codecs.bytes import BytesCodec, Endian from zarr.codecs.crc32c_ import Crc32cCodec @@ -9,6 +13,7 @@ from zarr.codecs.transpose import TransposeCodec from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec from zarr.codecs.zstd import ZstdCodec +from zarr.core.metadata.v3 import DataType __all__ = [ "BatchedCodecPipeline", @@ -26,3 +31,15 @@ "VLenBytesCodec", "ZstdCodec", ] + + +def get_default_array_bytes_codec( + np_dtype: np.dtype[Any], +) -> BytesCodec | VLenUTF8Codec | VLenBytesCodec: + dtype = DataType.from_numpy(np_dtype) + if dtype == DataType.string: + return VLenUTF8Codec() + elif dtype == DataType.bytes: + return VLenBytesCodec() + else: + return BytesCodec() diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 9a78297c6f..51003b6b16 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -11,7 +11,7 @@ from zarr._compat import _deprecate_positional_args from zarr.abc.store import Store, set_or_delete -from zarr.codecs import BytesCodec +from zarr.codecs import get_default_array_bytes_codec from zarr.codecs._v2 import V2Compressor, V2Filters from zarr.core.attributes import Attributes from zarr.core.buffer import ( @@ -318,7 +318,7 @@ async def _create_v3( await ensure_no_existing_node(store_path, zarr_format=3) shape = parse_shapelike(shape) - codecs = list(codecs) if codecs is not None else [BytesCodec()] + codecs = list(codecs) if codecs is not None else [get_default_array_bytes_codec(dtype)] if chunk_key_encoding is None: chunk_key_encoding = ("default", "/") diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 357c569b51..858a26fc7b 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -64,6 +64,34 @@ def parse_codecs(data: object) -> tuple[Codec, ...]: return out +def validate_codecs(codecs: tuple[Codec, ...], dtype: DataType) -> None: + """Check that the codecs are valid for the given dtype""" + + # ensure that we have at least one ArrayBytesCodec + abcs: list[ArrayBytesCodec] = [] + for codec in codecs: + if isinstance(codec, ArrayBytesCodec): + abcs.append(codec) + if len(abcs) == 0: + raise ValueError("At least one ArrayBytesCodec is required.") + elif len(abcs) > 1: + raise ValueError("Only one ArrayBytesCodec is allowed.") + + abc = abcs[0] + + # we need to have special codecs if we are decoding vlen strings or bytestrings + # TODO: use codec ID instead of class name + codec_id = abc.__class__.__name__ + if dtype == DataType.string and not codec_id == "VLenUTF8Codec": + raise ValueError( + f"For string dtype, ArrayBytesCodec must be `VLenUTF8Codec`, got `{codec_id}`." 
+ ) + if dtype == DataType.bytes and not codec_id == "VLenBytesCodec": + raise ValueError( + f"For bytes dtype, ArrayBytesCodec must be `VLenBytesCodec`, got `{codec_id}`." + ) + + def parse_dimension_names(data: object) -> tuple[str | None, ...] | None: if data is None: return data @@ -186,6 +214,8 @@ def __init__( chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid) chunk_key_encoding_parsed = ChunkKeyEncoding.from_dict(chunk_key_encoding) dimension_names_parsed = parse_dimension_names(dimension_names) + if fill_value is None: + fill_value = default_fill_value(data_type_parsed) fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed.to_numpy()) attributes_parsed = parse_attributes(attributes) codecs_parsed_partial = parse_codecs(codecs) @@ -199,6 +229,7 @@ def __init__( prototype=default_buffer_prototype(), # TODO: prototype is not needed here. ) codecs_parsed = [c.evolve_from_array_spec(array_spec) for c in codecs_parsed_partial] + validate_codecs(codecs_parsed_partial, data_type_parsed) object.__setattr__(self, "shape", shape_parsed) object.__setattr__(self, "data_type", data_type_parsed) @@ -360,8 +391,17 @@ def parse_fill_value( ... +def default_fill_value(dtype: DataType) -> str | bytes | np.generic: + if dtype == DataType.string: + return "" + elif dtype == DataType.bytes: + return b"" + else: + return dtype.to_numpy().type(0) + + def parse_fill_value( - fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool | None, + fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, dtype: BOOL_DTYPE | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE | np.dtype[Any], ) -> BOOL | INTEGER | FLOAT | COMPLEX | Any: """ @@ -385,7 +425,7 @@ def parse_fill_value( A scalar instance of `dtype` """ if fill_value is None: - return dtype.type(0) + raise ValueError("Fill value cannot be None") if dtype.kind == "O": return fill_value if isinstance(fill_value, Sequence) and not isinstance(fill_value, str): diff --git a/tests/v3/test_codecs/test_vlen.py b/tests/v3/test_codecs/test_vlen.py index e6de2ed767..eeed72b696 100644 --- a/tests/v3/test_codecs/test_vlen.py +++ b/tests/v3/test_codecs/test_vlen.py @@ -4,8 +4,9 @@ import pytest from zarr import Array +from zarr.abc.codec import Codec from zarr.abc.store import Store -from zarr.codecs import VLenBytesCodec, VLenUTF8Codec +from zarr.codecs import VLenBytesCodec, VLenUTF8Codec, ZstdCodec from zarr.core.metadata.v3 import ArrayV3Metadata, DataType from zarr.storage.common import StorePath from zarr.strings import NUMPY_SUPPORTS_VLEN_STRING @@ -21,11 +22,13 @@ @pytest.mark.parametrize("store", ["memory", "local"], indirect=["store"]) @pytest.mark.parametrize("dtype", numpy_str_dtypes) -async def test_vlen_string(store: Store, dtype: None | np.dtype[Any]) -> None: +@pytest.mark.parametrize("as_object_array", [False, True]) +@pytest.mark.parametrize("codecs", [None, [VLenUTF8Codec()], [VLenUTF8Codec(), ZstdCodec()]]) +def test_vlen_string( + store: Store, dtype: None | np.dtype[Any], as_object_array: bool, codecs: None | list[Codec] +) -> None: strings = ["hello", "world", "this", "is", "a", "test"] - data = np.array(strings).reshape((2, 3)) - if dtype is not None: - data = data.astype(dtype) + data = np.array(strings, dtype=dtype).reshape((2, 3)) sp = StorePath(store, path="string") a = Array.create( @@ -34,10 +37,15 @@ async def test_vlen_string(store: Store, dtype: None | np.dtype[Any]) -> None: chunk_shape=data.shape, dtype=data.dtype, fill_value="", - codecs=[VLenUTF8Codec()], + codecs=codecs, ) assert 
isinstance(a.metadata, ArrayV3Metadata) # needed for mypy + # should also work if input array is an object array, provided we explicitly specified + # a stringlike dtype when creating the Array + if as_object_array: + data = data.astype("O") + a[:, :] = data assert np.array_equal(data, a[:, :]) assert a.metadata.data_type == DataType.string @@ -52,7 +60,9 @@ async def test_vlen_string(store: Store, dtype: None | np.dtype[Any]) -> None: @pytest.mark.parametrize("store", ["memory", "local"], indirect=["store"]) -async def test_vlen_bytes(store: Store) -> None: +@pytest.mark.parametrize("as_object_array", [False, True]) +@pytest.mark.parametrize("codecs", [None, [VLenBytesCodec()], [VLenBytesCodec(), ZstdCodec()]]) +def test_vlen_bytes(store: Store, as_object_array: bool, codecs: None | list[Codec]) -> None: bstrings = [b"hello", b"world", b"this", b"is", b"a", b"test"] data = np.array(bstrings).reshape((2, 3)) assert data.dtype == "|S5" @@ -64,10 +74,14 @@ async def test_vlen_bytes(store: Store) -> None: chunk_shape=data.shape, dtype=data.dtype, fill_value=b"", - codecs=[VLenBytesCodec()], + codecs=codecs, ) assert isinstance(a.metadata, ArrayV3Metadata) # needed for mypy + # should also work if input array is an object array, provided we explicitly specified + # a bytesting-like dtype when creating the Array + if as_object_array: + data = data.astype("O") a[:, :] = data assert np.array_equal(data, a[:, :]) assert a.metadata.data_type == DataType.bytes @@ -79,3 +93,21 @@ async def test_vlen_bytes(store: Store) -> None: assert np.array_equal(data, b[:, :]) assert b.metadata.data_type == DataType.bytes assert a.dtype == "O" + + +@pytest.mark.parametrize("store", ["memory"], indirect=["store"]) +def test_vlen_errors(store: Store) -> None: + sp = StorePath(store, path="string") + + # fill value must be a compatible type + with pytest.raises(ValueError, match="fill value 0 is not valid"): + Array.create(sp, shape=5, chunk_shape=5, dtype=" None: @pytest.mark.parametrize("dtype_str", dtypes) -def test_parse_auto_fill_value(dtype_str: str) -> None: +def test_default_fill_value(dtype_str: str) -> None: """ Test that parse_fill_value(None, dtype) results in the 0 value for the given dtype. 
""" - dtype = np.dtype(dtype_str) - fill_value = None - assert parse_fill_value(fill_value, dtype) == dtype.type(0) + dtype = DataType(dtype_str) + fill_value = default_fill_value(dtype) + if dtype == DataType.string: + assert fill_value == "" + elif dtype == DataType.bytes: + assert fill_value == b"" + else: + assert fill_value == dtype.to_numpy().type(0) @pytest.mark.parametrize( @@ -337,7 +344,7 @@ async def test_special_float_fill_values(fill_value: str) -> None: "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}}, "data_type": "float64", "chunk_key_encoding": {"name": "default", "separator": "."}, - "codecs": (), + "codecs": [{"name": "bytes"}], "fill_value": fill_value, # this is not a valid fill value for uint8 } m = ArrayV3Metadata.from_dict(metadata_dict) From 28d58fad199e4336a96eef793f872d9f07a497fc Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Mon, 7 Oct 2024 18:44:35 -0400 Subject: [PATCH 14/29] retype parse_fill_value --- src/zarr/core/metadata/v3.py | 34 +++++++++++++++---------- tests/v3/test_codecs/test_vlen.py | 41 +++++++++++++++++++++---------- tests/v3/test_metadata/test_v3.py | 18 +++++++------- 3 files changed, 58 insertions(+), 35 deletions(-) diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 858a26fc7b..28cb8af415 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -216,7 +216,7 @@ def __init__( dimension_names_parsed = parse_dimension_names(dimension_names) if fill_value is None: fill_value = default_fill_value(data_type_parsed) - fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed.to_numpy()) + fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed) attributes_parsed = parse_attributes(attributes) codecs_parsed_partial = parse_codecs(codecs) storage_transformers_parsed = parse_storage_transformers(storage_transformers) @@ -346,18 +346,20 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self: FLOAT = np.float16 | np.float32 | np.float64 COMPLEX_DTYPE = np.dtypes.Complex64DType | np.dtypes.Complex128DType COMPLEX = np.complex64 | np.complex128 - +STRING_DTYPE = Literal[DataType.string] +STRING = np.str_ +BYTES = np.bytes_ @overload def parse_fill_value( - fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool | None, + fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, dtype: BOOL_DTYPE, ) -> BOOL: ... @overload def parse_fill_value( - fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool | None, + fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, dtype: INTEGER_DTYPE, ) -> INTEGER: ... @@ -402,7 +404,7 @@ def default_fill_value(dtype: DataType) -> str | bytes | np.generic: def parse_fill_value( fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, - dtype: BOOL_DTYPE | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE | np.dtype[Any], + dtype: DataType, ) -> BOOL | INTEGER | FLOAT | COMPLEX | Any: """ Parse `fill_value`, a potential fill value, into an instance of `dtype`, a data type. @@ -417,8 +419,8 @@ def parse_fill_value( ---------- fill_value: Any A potential fill value. - dtype: BOOL_DTYPE | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE - A numpy data type that models a data type defined in the Zarr V3 specification. + dtype: DataType + A valid Zarr V3 DataType. 
Returns ------- @@ -426,14 +428,20 @@ def parse_fill_value( """ if fill_value is None: raise ValueError("Fill value cannot be None") - if dtype.kind == "O": - return fill_value + if dtype == DataType.string: + return np.str_(fill_value) + if dtype == DataType.bytes: + return np.bytes_(fill_value) + + # the rest are numeric types + np_dtype = dtype.to_numpy() + if isinstance(fill_value, Sequence) and not isinstance(fill_value, str): - if dtype.type in (np.complex64, np.complex128): + if dtype in (DataType.complex64, DataType.complex128): dtype = cast(COMPLEX_DTYPE, dtype) if len(fill_value) == 2: # complex datatypes serialize to JSON arrays with two elements - return dtype.type(complex(*fill_value)) + return np_dtype.type(complex(*fill_value)) else: msg = ( f"Got an invalid fill value for complex data type {dtype}." @@ -452,7 +460,7 @@ def parse_fill_value( # fill_value != casted_value below. with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) - casted_value = np.dtype(dtype).type(fill_value) + casted_value = np.dtype(np_dtype).type(fill_value) except (ValueError, OverflowError, TypeError) as e: raise ValueError(f"fill value {fill_value!r} is not valid for dtype {dtype}") from e # Check if the value is still representable by the dtype @@ -460,7 +468,7 @@ def parse_fill_value( pass elif fill_value in ["Infinity", "-Infinity"] and not np.isfinite(casted_value): pass - elif dtype.kind in "cf": + elif np_dtype.kind in "cf": # float comparison is not exact, especially when dtype None: - sp = StorePath(store, path="string") +def test_default_fill_values(store: Store) -> None: + a = Array.create(StorePath(store, path="string"), shape=5, chunk_shape=5, dtype=" None: + with pytest.raises(ValueError, match="At least one ArrayBytesCodec is required."): + Array.create(StorePath(store, path="a"), shape=5, chunk_shape=5, dtype=" None: """ Test that parse_fill_value(fill_value, dtype) casts fill_value to the given dtype. """ - dtype = np.dtype(dtype_str) + dtype = DataType(dtype_str) parsed = parse_fill_value(fill_value, dtype) if np.isnan(fill_value): assert np.isnan(parsed) else: - assert parsed == dtype.type(fill_value) + assert parsed == dtype.to_numpy().type(fill_value) @pytest.mark.parametrize("fill_value", ["not a valid value"]) @@ -125,7 +125,7 @@ def test_parse_fill_value_invalid_value(fill_value: Any, dtype_str: str) -> None Test that parse_fill_value(fill_value, dtype) raises ValueError for invalid values. This test excludes bool because the bool constructor takes anything. 
""" - dtype = np.dtype(dtype_str) + dtype = DataType(dtype_str) with pytest.raises(ValueError): parse_fill_value(fill_value, dtype) @@ -137,11 +137,11 @@ def test_parse_fill_value_complex(fill_value: Any, dtype_str: str) -> None: Test that parse_fill_value(fill_value, dtype) correctly handles complex values represented as length-2 sequences """ - dtype = np.dtype(dtype_str) + dtype = DataType(dtype_str) if isinstance(fill_value, list): - expected = dtype.type(complex(*fill_value)) + expected = dtype.to_numpy().type(complex(*fill_value)) else: - expected = dtype.type(fill_value) + expected = dtype.to_numpy().type(fill_value) assert expected == parse_fill_value(fill_value, dtype) @@ -152,7 +152,7 @@ def test_parse_fill_value_complex_invalid(fill_value: Any, dtype_str: str) -> No Test that parse_fill_value(fill_value, dtype) correctly rejects sequences with length not equal to 2 """ - dtype = np.dtype(dtype_str) + dtype = DataType(dtype_str) match = ( f"Got an invalid fill value for complex data type {dtype}." f"Expected a sequence with 2 elements, but {fill_value} has " @@ -169,7 +169,7 @@ def test_parse_fill_value_invalid_type(fill_value: Any, dtype_str: str) -> None: Test that parse_fill_value(fill_value, dtype) raises TypeError for invalid non-sequential types. This test excludes bool because the bool constructor takes anything. """ - dtype = np.dtype(dtype_str) + dtype = DataType(dtype_str) with pytest.raises(ValueError, match=r"fill value .* is not valid for dtype .*"): parse_fill_value(fill_value, dtype) @@ -190,7 +190,7 @@ def test_parse_fill_value_invalid_type_sequence(fill_value: Any, dtype_str: str) This test excludes bool because the bool constructor takes anything, and complex because complex values can be created from length-2 sequences. 
""" - dtype = np.dtype(dtype_str) + dtype = DataType(dtype_str) match = f"Cannot parse non-string sequence {fill_value} as a scalar with type {dtype}" with pytest.raises(TypeError, match=re.escape(match)): parse_fill_value(fill_value, dtype) From c6de8780d51457c2552ea969c9598d1304522179 Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Mon, 7 Oct 2024 19:09:41 -0400 Subject: [PATCH 15/29] tests pass but not mypy --- src/zarr/codecs/__init__.py | 5 +++-- src/zarr/core/metadata/v3.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/zarr/codecs/__init__.py b/src/zarr/codecs/__init__.py index 15c0e4c8f6..6a15b1e487 100644 --- a/src/zarr/codecs/__init__.py +++ b/src/zarr/codecs/__init__.py @@ -1,8 +1,9 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any -import numpy as np +if TYPE_CHECKING: + import numpy as np from zarr.codecs.blosc import BloscCname, BloscCodec, BloscShuffle from zarr.codecs.bytes import BytesCodec, Endian diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 28cb8af415..6269a858d1 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -346,10 +346,11 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self: FLOAT = np.float16 | np.float32 | np.float64 COMPLEX_DTYPE = np.dtypes.Complex64DType | np.dtypes.Complex128DType COMPLEX = np.complex64 | np.complex128 -STRING_DTYPE = Literal[DataType.string] STRING = np.str_ +BYTES_DTYPE = np.dtypes.BytesDType BYTES = np.bytes_ + @overload def parse_fill_value( fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, From 4f026dbb11bed93477b20b405a5228a3389bcae1 Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Mon, 7 Oct 2024 20:25:25 -0400 Subject: [PATCH 16/29] attempted to change parse_fill_value typing --- src/zarr/core/metadata/v3.py | 120 +++++++++++++++++++----------- tests/v3/test_metadata/test_v3.py | 23 +++--- 2 files changed, 85 insertions(+), 58 deletions(-) diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 6269a858d1..fd28611b9e 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, cast, overload +from typing import TYPE_CHECKING, overload if TYPE_CHECKING: from typing import Self @@ -14,7 +14,7 @@ from collections.abc import Iterable, Sequence from dataclasses import dataclass, field, replace from enum import Enum -from typing import Any, Literal +from typing import Any, Literal, cast import numcodecs.abc import numpy as np @@ -29,7 +29,7 @@ from zarr.core.config import config from zarr.core.metadata.common import ArrayMetadata, parse_attributes from zarr.registry import get_codec_class -from zarr.strings import STRING_DTYPE +from zarr.strings import STRING_DTYPE as STRING_NP_DTYPE DEFAULT_DTYPE = "float64" @@ -216,7 +216,10 @@ def __init__( dimension_names_parsed = parse_dimension_names(dimension_names) if fill_value is None: fill_value = default_fill_value(data_type_parsed) - fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed) + # we pass a string here rather than an enum to make mypy happy + fill_value_parsed = parse_fill_value( + fill_value, dtype_value=cast(ALL_DTYPES, data_type_parsed.value) + ) attributes_parsed = parse_attributes(attributes) codecs_parsed_partial = parse_codecs(codecs) storage_transformers_parsed = parse_storage_transformers(storage_transformers) @@ -329,27 +332,38 @@ def 
update_attributes(self, attributes: dict[str, JSON]) -> Self: return replace(self, attributes=attributes) +# enum Literals can't be used in typing, so we have to restate all of the V3 dtypes as types +# https://github.com/python/typing/issues/781 + BOOL = np.bool_ -BOOL_DTYPE = np.dtypes.BoolDType -INTEGER_DTYPE = ( - np.dtypes.Int8DType - | np.dtypes.Int16DType - | np.dtypes.Int32DType - | np.dtypes.Int64DType - | np.dtypes.UInt8DType - | np.dtypes.UInt16DType - | np.dtypes.UInt32DType - | np.dtypes.UInt64DType -) +# BOOL_DTYPE = np.dtypes.BoolDType +BOOL_DTYPE = Literal["bool"] +INTEGER_DTYPE = Literal["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"] +# INTEGER_DTYPE = ( +# np.dtypes.Int8DType +# | np.dtypes.Int16DType +# | np.dtypes.Int32DType +# | np.dtypes.Int64DType +# | np.dtypes.UInt8DType +# | np.dtypes.UInt16DType +# | np.dtypes.UInt32DType +# | np.dtypes.UInt64DType +# ) INTEGER = np.int8 | np.int16 | np.int32 | np.int64 | np.uint8 | np.uint16 | np.uint32 | np.uint64 -FLOAT_DTYPE = np.dtypes.Float16DType | np.dtypes.Float32DType | np.dtypes.Float64DType +# FLOAT_DTYPE = np.dtypes.Float16DType | np.dtypes.Float32DType | np.dtypes.Float64DType +FLOAT_DTYPE = Literal["float16", "float32", "float64"] FLOAT = np.float16 | np.float32 | np.float64 -COMPLEX_DTYPE = np.dtypes.Complex64DType | np.dtypes.Complex128DType +# COMPLEX_DTYPE = np.dtypes.Complex64DType | np.dtypes.Complex128DType +COMPLEX_DTYPE = Literal["complex64", "complex128"] COMPLEX = np.complex64 | np.complex128 +STRING_DTYPE = Literal["string"] STRING = np.str_ -BYTES_DTYPE = np.dtypes.BytesDType +# BYTES_DTYPE = np.dtypes.BytesDType +BYTES_DTYPE = Literal["bytes"] BYTES = np.bytes_ +ALL_DTYPES = INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE | STRING_DTYPE | BYTES_DTYPE + @overload def parse_fill_value( @@ -367,45 +381,50 @@ def parse_fill_value( @overload def parse_fill_value( - fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool | None, + fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, dtype: FLOAT_DTYPE, ) -> FLOAT: ... @overload def parse_fill_value( - fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool | None, + fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, dtype: COMPLEX_DTYPE, ) -> COMPLEX: ... @overload def parse_fill_value( - fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool | None, - dtype: np.dtype[Any], -) -> Any: - # This dtype[Any] is unfortunately necessary right now. - # See https://github.com/zarr-developers/zarr-python/issues/2131#issuecomment-2318010899 - # for more details, but `dtype` here (which comes from `parse_dtype`) - # is np.dtype[Any]. - # - # If you want the specialized types rather than Any, you need to use `np.dtypes.` - # rather than np.dtypes() - ... + fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, + dtype: STRING_DTYPE, +) -> STRING: ... -def default_fill_value(dtype: DataType) -> str | bytes | np.generic: - if dtype == DataType.string: - return "" - elif dtype == DataType.bytes: - return b"" - else: - return dtype.to_numpy().type(0) +@overload +def parse_fill_value( + fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, + dtype: BYTES_DTYPE, +) -> BYTES: ... + + +# @overload +# def parse_fill_value( +# fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool | None, +# dtype: np.dtype[Any], +# ) -> Any: +# # This dtype[Any] is unfortunately necessary right now. 
+# # See https://github.com/zarr-developers/zarr-python/issues/2131#issuecomment-2318010899 +# # for more details, but `dtype` here (which comes from `parse_dtype`) +# # is np.dtype[Any]. +# # +# # If you want the specialized types rather than Any, you need to use `np.dtypes.` +# # rather than np.dtypes() +# ... def parse_fill_value( fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, - dtype: DataType, + dtype_value: ALL_DTYPES, ) -> BOOL | INTEGER | FLOAT | COMPLEX | Any: """ Parse `fill_value`, a potential fill value, into an instance of `dtype`, a data type. @@ -420,13 +439,15 @@ def parse_fill_value( ---------- fill_value: Any A potential fill value. - dtype: DataType + dtype_value: str A valid Zarr V3 DataType. Returns ------- A scalar instance of `dtype` """ + print("dtype_value", dtype_value) + dtype = DataType(dtype_value) if fill_value is None: raise ValueError("Fill value cannot be None") if dtype == DataType.string: @@ -439,18 +460,20 @@ def parse_fill_value( if isinstance(fill_value, Sequence) and not isinstance(fill_value, str): if dtype in (DataType.complex64, DataType.complex128): - dtype = cast(COMPLEX_DTYPE, dtype) + # dtype = cast(np.dtypes.Complex64DType | np.dtypes.Complex128DType, np_dtype) if len(fill_value) == 2: # complex datatypes serialize to JSON arrays with two elements return np_dtype.type(complex(*fill_value)) else: msg = ( - f"Got an invalid fill value for complex data type {dtype}." + f"Got an invalid fill value for complex data type {dtype.value}." f"Expected a sequence with 2 elements, but {fill_value!r} has " f"length {len(fill_value)}." ) raise ValueError(msg) - msg = f"Cannot parse non-string sequence {fill_value!r} as a scalar with type {dtype}." + msg = ( + f"Cannot parse non-string sequence {fill_value!r} as a scalar with type {dtype.value}." + ) raise TypeError(msg) # Cast the fill_value to the given dtype @@ -482,6 +505,15 @@ def parse_fill_value( return casted_value +def default_fill_value(dtype: DataType) -> str | bytes | np.generic: + if dtype == DataType.string: + return "" + elif dtype == DataType.bytes: + return b"" + else: + return dtype.to_numpy().type(0) + + # For type checking _bool = bool @@ -550,7 +582,7 @@ def to_numpy_shortname(self) -> str: def to_numpy(self) -> np.dtype[Any]: if self == DataType.string: - return STRING_DTYPE + return STRING_NP_DTYPE elif self == DataType.bytes: # for now always use object dtype for bytestrings # TODO: consider whether we can use fixed-width types (e.g. '|S5') instead diff --git a/tests/v3/test_metadata/test_v3.py b/tests/v3/test_metadata/test_v3.py index b0dc89a87d..5dad08b204 100644 --- a/tests/v3/test_metadata/test_v3.py +++ b/tests/v3/test_metadata/test_v3.py @@ -109,13 +109,12 @@ def test_parse_fill_value_valid(fill_value: Any, dtype_str: str) -> None: """ Test that parse_fill_value(fill_value, dtype) casts fill_value to the given dtype. """ - dtype = DataType(dtype_str) - parsed = parse_fill_value(fill_value, dtype) + parsed = parse_fill_value(fill_value, dtype_str) if np.isnan(fill_value): assert np.isnan(parsed) else: - assert parsed == dtype.to_numpy().type(fill_value) + assert parsed == DataType(dtype_str).to_numpy().type(fill_value) @pytest.mark.parametrize("fill_value", ["not a valid value"]) @@ -125,9 +124,8 @@ def test_parse_fill_value_invalid_value(fill_value: Any, dtype_str: str) -> None Test that parse_fill_value(fill_value, dtype) raises ValueError for invalid values. This test excludes bool because the bool constructor takes anything. 
""" - dtype = DataType(dtype_str) with pytest.raises(ValueError): - parse_fill_value(fill_value, dtype) + parse_fill_value(fill_value, dtype_str) @pytest.mark.parametrize("fill_value", [[1.0, 0.0], [0, 1], complex(1, 1), np.complex64(0)]) @@ -142,7 +140,7 @@ def test_parse_fill_value_complex(fill_value: Any, dtype_str: str) -> None: expected = dtype.to_numpy().type(complex(*fill_value)) else: expected = dtype.to_numpy().type(fill_value) - assert expected == parse_fill_value(fill_value, dtype) + assert expected == parse_fill_value(fill_value, dtype_str) @pytest.mark.parametrize("fill_value", [[1.0, 0.0, 3.0], [0, 1, 3], [1]]) @@ -152,14 +150,13 @@ def test_parse_fill_value_complex_invalid(fill_value: Any, dtype_str: str) -> No Test that parse_fill_value(fill_value, dtype) correctly rejects sequences with length not equal to 2 """ - dtype = DataType(dtype_str) match = ( - f"Got an invalid fill value for complex data type {dtype}." + f"Got an invalid fill value for complex data type {dtype_str}." f"Expected a sequence with 2 elements, but {fill_value} has " f"length {len(fill_value)}." ) with pytest.raises(ValueError, match=re.escape(match)): - parse_fill_value(fill_value=fill_value, dtype=dtype) + parse_fill_value(fill_value=fill_value, dtype_value=dtype_str) @pytest.mark.parametrize("fill_value", [{"foo": 10}]) @@ -169,9 +166,8 @@ def test_parse_fill_value_invalid_type(fill_value: Any, dtype_str: str) -> None: Test that parse_fill_value(fill_value, dtype) raises TypeError for invalid non-sequential types. This test excludes bool because the bool constructor takes anything. """ - dtype = DataType(dtype_str) with pytest.raises(ValueError, match=r"fill value .* is not valid for dtype .*"): - parse_fill_value(fill_value, dtype) + parse_fill_value(fill_value, dtype_str) @pytest.mark.parametrize( @@ -190,10 +186,9 @@ def test_parse_fill_value_invalid_type_sequence(fill_value: Any, dtype_str: str) This test excludes bool because the bool constructor takes anything, and complex because complex values can be created from length-2 sequences. 
""" - dtype = DataType(dtype_str) - match = f"Cannot parse non-string sequence {fill_value} as a scalar with type {dtype}" + match = f"Cannot parse non-string sequence {fill_value} as a scalar with type {dtype_str}" with pytest.raises(TypeError, match=re.escape(match)): - parse_fill_value(fill_value, dtype) + parse_fill_value(fill_value, dtype_str) @pytest.mark.parametrize("chunk_grid", ["regular"]) From e427c7aea0fb756ebe9139f434a09d90d1e1a6e9 Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Mon, 7 Oct 2024 20:39:48 -0400 Subject: [PATCH 17/29] restore DEFAULT_DTYPE --- src/zarr/core/metadata/v3.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index fd28611b9e..2fea7fffc8 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -618,8 +618,7 @@ def from_numpy(cls, dtype: np.dtype[Any]) -> DataType: @classmethod def parse(cls, dtype: None | DataType | Any) -> DataType: if dtype is None: - # the default dtype - return DataType.float64 + return DataType[DEFAULT_DTYPE] if isinstance(dtype, DataType): return dtype try: From 7d9d89736245c8e547c9cda0b70683de74982dfd Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Mon, 7 Oct 2024 20:57:07 -0500 Subject: [PATCH 18/29] fixup --- src/zarr/core/array.py | 4 ++- src/zarr/core/metadata/v3.py | 54 +++++++++++++----------------------- 2 files changed, 23 insertions(+), 35 deletions(-) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index 51003b6b16..a332761b7f 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -318,7 +318,9 @@ async def _create_v3( await ensure_no_existing_node(store_path, zarr_format=3) shape = parse_shapelike(shape) - codecs = list(codecs) if codecs is not None else [get_default_array_bytes_codec(dtype)] + codecs = ( + list(codecs) if codecs is not None else [get_default_array_bytes_codec(np.dtype(dtype))] + ) if chunk_key_encoding is None: chunk_key_encoding = ("default", "/") diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 2fea7fffc8..34d49de05e 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -218,7 +218,7 @@ def __init__( fill_value = default_fill_value(data_type_parsed) # we pass a string here rather than an enum to make mypy happy fill_value_parsed = parse_fill_value( - fill_value, dtype_value=cast(ALL_DTYPES, data_type_parsed.value) + fill_value, dtype=cast(ALL_DTYPES, data_type_parsed.value) ) attributes_parsed = parse_attributes(attributes) codecs_parsed_partial = parse_codecs(codecs) @@ -362,7 +362,7 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self: BYTES_DTYPE = Literal["bytes"] BYTES = np.bytes_ -ALL_DTYPES = INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE | STRING_DTYPE | BYTES_DTYPE +ALL_DTYPES = BOOL_DTYPE | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE | STRING_DTYPE | BYTES_DTYPE @overload @@ -407,25 +407,10 @@ def parse_fill_value( ) -> BYTES: ... -# @overload -# def parse_fill_value( -# fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool | None, -# dtype: np.dtype[Any], -# ) -> Any: -# # This dtype[Any] is unfortunately necessary right now. -# # See https://github.com/zarr-developers/zarr-python/issues/2131#issuecomment-2318010899 -# # for more details, but `dtype` here (which comes from `parse_dtype`) -# # is np.dtype[Any]. -# # -# # If you want the specialized types rather than Any, you need to use `np.dtypes.` -# # rather than np.dtypes() -# ... 
- - def parse_fill_value( - fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool, - dtype_value: ALL_DTYPES, -) -> BOOL | INTEGER | FLOAT | COMPLEX | Any: + fill_value: Any, + dtype: ALL_DTYPES, +) -> Any: """ Parse `fill_value`, a potential fill value, into an instance of `dtype`, a data type. If `fill_value` is `None`, then this function will return the result of casting the value 0 @@ -446,34 +431,32 @@ def parse_fill_value( ------- A scalar instance of `dtype` """ - print("dtype_value", dtype_value) - dtype = DataType(dtype_value) + print("dtype_value", dtype) + data_type = DataType(dtype) if fill_value is None: raise ValueError("Fill value cannot be None") - if dtype == DataType.string: + if data_type == DataType.string: return np.str_(fill_value) - if dtype == DataType.bytes: + if data_type == DataType.bytes: return np.bytes_(fill_value) # the rest are numeric types - np_dtype = dtype.to_numpy() + np_dtype = data_type.to_numpy() if isinstance(fill_value, Sequence) and not isinstance(fill_value, str): - if dtype in (DataType.complex64, DataType.complex128): + if data_type in (DataType.complex64, DataType.complex128): # dtype = cast(np.dtypes.Complex64DType | np.dtypes.Complex128DType, np_dtype) if len(fill_value) == 2: # complex datatypes serialize to JSON arrays with two elements return np_dtype.type(complex(*fill_value)) else: msg = ( - f"Got an invalid fill value for complex data type {dtype.value}." + f"Got an invalid fill value for complex data type {data_type.value}." f"Expected a sequence with 2 elements, but {fill_value!r} has " f"length {len(fill_value)}." ) raise ValueError(msg) - msg = ( - f"Cannot parse non-string sequence {fill_value!r} as a scalar with type {dtype.value}." - ) + msg = f"Cannot parse non-string sequence {fill_value!r} as a scalar with type {data_type.value}." raise TypeError(msg) # Cast the fill_value to the given dtype @@ -486,7 +469,7 @@ def parse_fill_value( warnings.filterwarnings("ignore", category=DeprecationWarning) casted_value = np.dtype(np_dtype).type(fill_value) except (ValueError, OverflowError, TypeError) as e: - raise ValueError(f"fill value {fill_value!r} is not valid for dtype {dtype}") from e + raise ValueError(f"fill value {fill_value!r} is not valid for dtype {data_type}") from e # Check if the value is still representable by the dtype if fill_value == "NaN" and np.isnan(casted_value): pass @@ -497,15 +480,18 @@ def parse_fill_value( # so we us np.isclose for this comparison. # this also allows us to compare nan fill_values if not np.isclose(fill_value, casted_value, equal_nan=True): - raise ValueError(f"fill value {fill_value!r} is not valid for dtype {dtype}") + raise ValueError(f"fill value {fill_value!r} is not valid for dtype {data_type}") else: if fill_value != casted_value: - raise ValueError(f"fill value {fill_value!r} is not valid for dtype {dtype}") + raise ValueError(f"fill value {fill_value!r} is not valid for dtype {data_type}") return casted_value -def default_fill_value(dtype: DataType) -> str | bytes | np.generic: +def default_fill_value(dtype: DataType) -> Any: + # TODO: the static types could maybe be narrowed here. + # mypy knows that np.dtype("int64").type(0) is an int64. + # so maybe DataType needs to be generic? 
if dtype == DataType.string: return "" elif dtype == DataType.bytes: From 0c21994190c9418963b6010f2ee8624c74c36382 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Mon, 7 Oct 2024 20:58:47 -0500 Subject: [PATCH 19/29] docstring --- src/zarr/core/metadata/v3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 34d49de05e..b65e27bea3 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -424,7 +424,7 @@ def parse_fill_value( ---------- fill_value: Any A potential fill value. - dtype_value: str + dtype: str A valid Zarr V3 DataType. Returns From c12ac4130275435ec54ae1d1d94be377832cc10d Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Tue, 8 Oct 2024 06:45:23 -0500 Subject: [PATCH 20/29] update test --- tests/v3/test_metadata/test_v3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/v3/test_metadata/test_v3.py b/tests/v3/test_metadata/test_v3.py index 5dad08b204..ffc98e2303 100644 --- a/tests/v3/test_metadata/test_v3.py +++ b/tests/v3/test_metadata/test_v3.py @@ -156,7 +156,7 @@ def test_parse_fill_value_complex_invalid(fill_value: Any, dtype_str: str) -> No f"length {len(fill_value)}." ) with pytest.raises(ValueError, match=re.escape(match)): - parse_fill_value(fill_value=fill_value, dtype_value=dtype_str) + parse_fill_value(fill_value=fill_value, dtype=dtype_str) @pytest.mark.parametrize("fill_value", [{"foo": 10}]) From 3aeea1e1a405862de544bf1b41445b7f4b5062b5 Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Tue, 8 Oct 2024 09:11:56 -0400 Subject: [PATCH 21/29] add better DataType tests --- src/zarr/core/metadata/v3.py | 23 +++++++++++++---------- tests/v3/test_metadata/test_v3.py | 14 ++++++++++++++ 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index b65e27bea3..d779825b60 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -431,7 +431,6 @@ def parse_fill_value( ------- A scalar instance of `dtype` """ - print("dtype_value", dtype) data_type = DataType(dtype) if fill_value is None: raise ValueError("Fill value cannot be None") @@ -488,10 +487,7 @@ def parse_fill_value( return casted_value -def default_fill_value(dtype: DataType) -> Any: - # TODO: the static types could maybe be narrowed here. - # mypy knows that np.dtype("int64").type(0) is an int64. - # so maybe DataType needs to be generic? +def default_fill_value(dtype: DataType) -> str | bytes | np.generic: if dtype == DataType.string: return "" elif dtype == DataType.bytes: @@ -523,7 +519,7 @@ class DataType(Enum): bytes = "bytes" @property - def byte_count(self) -> int: + def byte_count(self) -> None | int: data_type_byte_counts = { DataType.bool: 1, DataType.int8: 1, @@ -540,12 +536,15 @@ def byte_count(self) -> int: DataType.complex64: 8, DataType.complex128: 16, } - return data_type_byte_counts[self] + try: + return data_type_byte_counts[self] + except KeyError: + # string and bytes have variable length + return None @property def has_endianness(self) -> _bool: - # This might change in the future, e.g. 
for a complex with 2 8-bit floats - return self.byte_count != 1 + return self.byte_count is not None and self.byte_count != 1 def to_numpy_shortname(self) -> str: data_type_to_numpy = { @@ -566,7 +565,11 @@ def to_numpy_shortname(self) -> str: } return data_type_to_numpy[self] - def to_numpy(self) -> np.dtype[Any]: + def to_numpy(self) -> np.dtype[np.generic]: + # note: it is not possible to round trip DataType <-> np.dtype + # due to the fact that DataType.string and DataType.bytes both + # generally return np.dtype("O") from this function, even though + # they can originate as fixed-length types (e.g. " None: elif fill_value == "-Infinity": assert np.isneginf(m.fill_value) assert d["fill_value"] == "-Infinity" + + +@pytest.mark.parametrize("dtype_str", dtypes) +def test_dtypes(dtype_str: str) -> None: + dt = DataType(dtype_str) + np_dtype = dt.to_numpy() + if dtype_str not in vlen_dtypes: + # we can round trip "normal" dtypes + assert dt == DataType.from_numpy(np_dtype) + assert dt.byte_count == np_dtype.itemsize + assert dt.has_endianness == (dt.byte_count > 1) + else: + # return type for vlen types may vary depending on numpy version + assert dt.byte_count is None From cae70557b750bdc0e0a0c98a12fcf1cf6cdf10a1 Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Tue, 8 Oct 2024 11:21:11 -0400 Subject: [PATCH 22/29] more progress on typing; still not passing mypy --- src/zarr/core/metadata/v3.py | 18 +-------- src/zarr/strings.py | 71 +++++++++++++++++++++++++++++++----- tests/test_strings.py | 35 ++++++++++++++++++ 3 files changed, 98 insertions(+), 26 deletions(-) create mode 100644 tests/test_strings.py diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index d779825b60..088890ff8d 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -335,30 +335,16 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self: # enum Literals can't be used in typing, so we have to restate all of the V3 dtypes as types # https://github.com/python/typing/issues/781 -BOOL = np.bool_ -# BOOL_DTYPE = np.dtypes.BoolDType BOOL_DTYPE = Literal["bool"] +BOOL = np.bool_ INTEGER_DTYPE = Literal["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"] -# INTEGER_DTYPE = ( -# np.dtypes.Int8DType -# | np.dtypes.Int16DType -# | np.dtypes.Int32DType -# | np.dtypes.Int64DType -# | np.dtypes.UInt8DType -# | np.dtypes.UInt16DType -# | np.dtypes.UInt32DType -# | np.dtypes.UInt64DType -# ) INTEGER = np.int8 | np.int16 | np.int32 | np.int64 | np.uint8 | np.uint16 | np.uint32 | np.uint64 -# FLOAT_DTYPE = np.dtypes.Float16DType | np.dtypes.Float32DType | np.dtypes.Float64DType FLOAT_DTYPE = Literal["float16", "float32", "float64"] FLOAT = np.float16 | np.float32 | np.float64 -# COMPLEX_DTYPE = np.dtypes.Complex64DType | np.dtypes.Complex128DType COMPLEX_DTYPE = Literal["complex64", "complex128"] COMPLEX = np.complex64 | np.complex128 STRING_DTYPE = Literal["string"] STRING = np.str_ -# BYTES_DTYPE = np.dtypes.BytesDType BYTES_DTYPE = Literal["bytes"] BYTES = np.bytes_ @@ -565,7 +551,7 @@ def to_numpy_shortname(self) -> str: } return data_type_to_numpy[self] - def to_numpy(self) -> np.dtype[np.generic]: + def to_numpy(self) -> np.dtypes.StringDType | np.dtypes.ObjectDType | np.dtype[np.generic]: # note: it is not possible to round trip DataType <-> np.dtype # due to the fact that DataType.string and DataType.bytes both # generally return np.dtype("O") from this function, even though diff --git a/src/zarr/strings.py b/src/zarr/strings.py index 
c3dd1a3395..0b7f9aeb6b 100644 --- a/src/zarr/strings.py +++ b/src/zarr/strings.py @@ -1,36 +1,87 @@ -from typing import Any +"""This module contains utilities for working with string arrays across +different versions of Numpy. +""" + +from typing import Any, cast from warnings import warn import numpy as np +# STRING_DTYPE is the in-memory datatype that will be used for V3 string arrays +# when reading data back from Zarr. +# Any valid string-like datatype should be fine for *setting* data. + +STRING_DTYPE: np.dtypes.StringDType | np.dtypes.ObjectDType +NUMPY_SUPPORTS_VLEN_STRING: bool + + +def cast_array( + data: np.ndarray[Any, np.dtype[Any]], +) -> np.ndarray[Any, np.dtypes.StringDType | np.dtypes.ObjectDType]: + raise NotImplementedError + + try: - STRING_DTYPE = np.dtype("T") + # this new vlen string dtype was added in NumPy 2.0 + STRING_DTYPE = np.dtypes.StringDType() NUMPY_SUPPORTS_VLEN_STRING = True -except TypeError: - STRING_DTYPE = np.dtype("object") + + def cast_array( + data: np.ndarray[Any, np.dtype[Any]], + ) -> np.ndarray[Any, np.dtypes.StringDType | np.dtypes.ObjectDType]: + out = data.astype(STRING_DTYPE, copy=False) + return cast(np.ndarray[Any, np.dtypes.StringDType], out) + +except AttributeError: + # if not available, we fall back on an object array of strings, as in Zarr < 3 + STRING_DTYPE = np.dtypes.ObjectDType() NUMPY_SUPPORTS_VLEN_STRING = False + def cast_array( + data: np.ndarray[Any, np.dtype[Any]], + ) -> np.ndarray[Any, np.dtypes.StringDType | np.dtypes.ObjectDType]: + out = data.astype(STRING_DTYPE, copy=False) + return cast(np.ndarray[Any, np.dtypes.ObjectDType], out) + def cast_to_string_dtype( data: np.ndarray[Any, np.dtype[Any]], safe: bool = False -) -> np.ndarray[Any, np.dtype[Any]]: +) -> np.ndarray[Any, np.dtypes.StringDType | np.dtypes.ObjectDType]: + """Take any data and attempt to cast to to our preferred string dtype. + + data : np.ndarray + The data to cast + + safe : bool + If True, do not issue a warning if the data is cast from object to string dtype. + + """ if np.issubdtype(data.dtype, np.str_): - return data + # legacy fixed-width string type (e.g. "= 2.", stacklevel=2, ) - return out + return cast_array(data) raise ValueError(f"Cannot cast dtype {data.dtype} to string dtype") diff --git a/tests/test_strings.py b/tests/test_strings.py new file mode 100644 index 0000000000..88cde18a3e --- /dev/null +++ b/tests/test_strings.py @@ -0,0 +1,35 @@ +"""Tests for the strings module.""" + +import numpy as np +import pytest + +from zarr.strings import NUMPY_SUPPORTS_VLEN_STRING, STRING_DTYPE, cast_to_string_dtype + + +def test_string_defaults() -> None: + if NUMPY_SUPPORTS_VLEN_STRING: + assert STRING_DTYPE == np.dtypes.StringDType() + else: + assert STRING_DTYPE == np.dtypes.ObjectDType() + + +def test_cast_to_string_dtype() -> None: + d1 = np.array(["a", "b", "c"]) + assert d1.dtype == np.dtype(" Date: Tue, 8 Oct 2024 11:24:44 -0400 Subject: [PATCH 23/29] fix typing yay! 
---
 src/zarr/core/metadata/v3.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py
index 088890ff8d..b181e3b032 100644
--- a/src/zarr/core/metadata/v3.py
+++ b/src/zarr/core/metadata/v3.py
@@ -427,6 +427,7 @@ def parse_fill_value(

     # the rest are numeric types
     np_dtype = data_type.to_numpy()
+    np_dtype = cast(np.dtype[np.generic], np_dtype)

     if isinstance(fill_value, Sequence) and not isinstance(fill_value, str):
         if data_type in (DataType.complex64, DataType.complex128):
@@ -479,7 +480,9 @@ def default_fill_value(dtype: DataType) -> str | bytes | np.generic:
     elif dtype == DataType.bytes:
         return b""
     else:
-        return dtype.to_numpy().type(0)
+        np_dtype = dtype.to_numpy()
+        np_dtype = cast(np.dtype[np.generic], np_dtype)
+        return np_dtype.type(0)


 # For type checking

From 6714bad58640db76e97b8adcb00739d3faec7c05 Mon Sep 17 00:00:00 2001
From: Ryan Abernathey
Date: Tue, 8 Oct 2024 11:37:04 -0400
Subject: [PATCH 24/29] make types work with numpy < 2

---
 src/zarr/strings.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/zarr/strings.py b/src/zarr/strings.py
index 0b7f9aeb6b..06d2489442 100644
--- a/src/zarr/strings.py
+++ b/src/zarr/strings.py
@@ -2,7 +2,7 @@
 different versions of Numpy.
 """

-from typing import Any, cast
+from typing import Any, Union, cast
 from warnings import warn

 import numpy as np
@@ -11,13 +11,13 @@
 # when reading data back from Zarr.
 # Any valid string-like datatype should be fine for *setting* data.

-STRING_DTYPE: np.dtypes.StringDType | np.dtypes.ObjectDType
+STRING_DTYPE: Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"]
 NUMPY_SUPPORTS_VLEN_STRING: bool


 def cast_array(
     data: np.ndarray[Any, np.dtype[Any]],
-) -> np.ndarray[Any, np.dtypes.StringDType | np.dtypes.ObjectDType]:
+) -> np.ndarray[Any, Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"]]:
     raise NotImplementedError


@@ -39,14 +39,14 @@ def cast_array(

     def cast_array(
         data: np.ndarray[Any, np.dtype[Any]],
-    ) -> np.ndarray[Any, np.dtypes.StringDType | np.dtypes.ObjectDType]:
+    ) -> np.ndarray[Any, Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"]]:
         out = data.astype(STRING_DTYPE, copy=False)
         return cast(np.ndarray[Any, np.dtypes.ObjectDType], out)


 def cast_to_string_dtype(
     data: np.ndarray[Any, np.dtype[Any]], safe: bool = False
-) -> np.ndarray[Any, np.dtypes.StringDType | np.dtypes.ObjectDType]:
+) -> np.ndarray[Any, Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"]]:
     """Take any data and attempt to cast to to our preferred string dtype.

data : np.ndarray From 2edf3b80accd846801e0e9451de0b2408b50c9a3 Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Tue, 8 Oct 2024 14:20:43 -0400 Subject: [PATCH 25/29] Apply suggestions from code review Co-authored-by: Joe Hamman --- src/zarr/codecs/__init__.py | 2 +- src/zarr/codecs/vlen_utf8.py | 4 ++-- src/zarr/core/buffer/core.py | 2 -- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/zarr/codecs/__init__.py b/src/zarr/codecs/__init__.py index 6a15b1e487..ac647d7863 100644 --- a/src/zarr/codecs/__init__.py +++ b/src/zarr/codecs/__init__.py @@ -34,7 +34,7 @@ ] -def get_default_array_bytes_codec( +def _get_default_array_bytes_codec( np_dtype: np.dtype[Any], ) -> BytesCodec | VLenUTF8Codec | VLenBytesCodec: dtype = DataType.from_numpy(np_dtype) diff --git a/src/zarr/codecs/vlen_utf8.py b/src/zarr/codecs/vlen_utf8.py index dd21f18999..b207cdcca7 100644 --- a/src/zarr/codecs/vlen_utf8.py +++ b/src/zarr/codecs/vlen_utf8.py @@ -19,8 +19,8 @@ # can use a global because there are no parameters -vlen_utf8_codec = VLenUTF8() -vlen_bytes_codec = VLenBytes() +_vlen_utf8_codec = VLenUTF8() +_vlen_bytes_codec = VLenBytes() @dataclass(frozen=True) diff --git a/src/zarr/core/buffer/core.py b/src/zarr/core/buffer/core.py index b520e21ee3..1fbf58c618 100644 --- a/src/zarr/core/buffer/core.py +++ b/src/zarr/core/buffer/core.py @@ -313,7 +313,6 @@ class NDBuffer: """ def __init__(self, array: NDArrayLike) -> None: - # assert array.dtype != object self._data = array @classmethod @@ -466,7 +465,6 @@ def all_equal(self, other: Any, equal_nan: bool = True) -> bool: # Handle None fill_value for Zarr V2 return False # use array_equal to obtain equal_nan=True functionality - # Note from Ryan: doesn't this lead to a huge amount of unnecessary memory allocation on every single chunk? # Since fill-value is a scalar, isn't there a faster path than allocating a new array for fill value # every single time we have to write data? 
_data, other = np.broadcast_arrays(self._data, other) From 12a0d65da54bfd434a0569e62a6ed720dc1e1360 Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Tue, 8 Oct 2024 14:22:30 -0400 Subject: [PATCH 26/29] Apply suggestions from code review Co-authored-by: Joe Hamman --- src/zarr/core/metadata/v3.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index b181e3b032..449f91058e 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -426,12 +426,10 @@ def parse_fill_value( return np.bytes_(fill_value) # the rest are numeric types - np_dtype = data_type.to_numpy() - np_dtype = cast(np.dtype[np.generic], np_dtype) + np_dtype = cast(np.dtype[np.generic], data_type.to_numpy()) if isinstance(fill_value, Sequence) and not isinstance(fill_value, str): if data_type in (DataType.complex64, DataType.complex128): - # dtype = cast(np.dtypes.Complex64DType | np.dtypes.Complex128DType, np_dtype) if len(fill_value) == 2: # complex datatypes serialize to JSON arrays with two elements return np_dtype.type(complex(*fill_value)) From 7ba70771bd05ec61800459fe16696381acb1957d Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Tue, 8 Oct 2024 14:34:59 -0400 Subject: [PATCH 27/29] apply Joe's suggestions --- src/zarr/codecs/vlen_utf8.py | 10 ++-- src/zarr/core/array.py | 6 ++- src/zarr/core/metadata/v3.py | 2 +- src/zarr/strings.py | 87 ------------------------------- tests/test_strings.py | 2 +- tests/v3/test_array.py | 38 +++++++++++++- tests/v3/test_codecs/test_vlen.py | 37 +------------ 7 files changed, 50 insertions(+), 132 deletions(-) delete mode 100644 src/zarr/strings.py diff --git a/src/zarr/codecs/vlen_utf8.py b/src/zarr/codecs/vlen_utf8.py index b207cdcca7..43544e0809 100644 --- a/src/zarr/codecs/vlen_utf8.py +++ b/src/zarr/codecs/vlen_utf8.py @@ -9,8 +9,8 @@ from zarr.abc.codec import ArrayBytesCodec from zarr.core.buffer import Buffer, NDBuffer from zarr.core.common import JSON, parse_named_configuration +from zarr.core.strings import cast_to_string_dtype from zarr.registry import register_codec -from zarr.strings import cast_to_string_dtype if TYPE_CHECKING: from typing import Self @@ -47,7 +47,7 @@ async def _decode_single( assert isinstance(chunk_bytes, Buffer) raw_bytes = chunk_bytes.as_array_like() - decoded = vlen_utf8_codec.decode(raw_bytes) + decoded = _vlen_utf8_codec.decode(raw_bytes) assert decoded.dtype == np.object_ decoded.shape = chunk_spec.shape # coming out of the code, we know this is safe, so don't issue a warning @@ -61,7 +61,7 @@ async def _encode_single( ) -> Buffer | None: assert isinstance(chunk_array, NDBuffer) return chunk_spec.prototype.buffer.from_bytes( - vlen_utf8_codec.encode(chunk_array.as_numpy_array()) + _vlen_utf8_codec.encode(chunk_array.as_numpy_array()) ) def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: @@ -93,7 +93,7 @@ async def _decode_single( assert isinstance(chunk_bytes, Buffer) raw_bytes = chunk_bytes.as_array_like() - decoded = vlen_bytes_codec.decode(raw_bytes) + decoded = _vlen_bytes_codec.decode(raw_bytes) assert decoded.dtype == np.object_ decoded.shape = chunk_spec.shape return chunk_spec.prototype.nd_buffer.from_numpy_array(decoded) @@ -105,7 +105,7 @@ async def _encode_single( ) -> Buffer | None: assert isinstance(chunk_array, NDBuffer) return chunk_spec.prototype.buffer.from_bytes( - vlen_bytes_codec.encode(chunk_array.as_numpy_array()) + _vlen_bytes_codec.encode(chunk_array.as_numpy_array()) ) def 
compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int: diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index a332761b7f..9f5591ce1e 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -11,7 +11,7 @@ from zarr._compat import _deprecate_positional_args from zarr.abc.store import Store, set_or_delete -from zarr.codecs import get_default_array_bytes_codec +from zarr.codecs import _get_default_array_bytes_codec from zarr.codecs._v2 import V2Compressor, V2Filters from zarr.core.attributes import Attributes from zarr.core.buffer import ( @@ -319,7 +319,9 @@ async def _create_v3( shape = parse_shapelike(shape) codecs = ( - list(codecs) if codecs is not None else [get_default_array_bytes_codec(np.dtype(dtype))] + list(codecs) + if codecs is not None + else [_get_default_array_bytes_codec(np.dtype(dtype))] ) if chunk_key_encoding is None: diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index 449f91058e..006b6cf84b 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -28,8 +28,8 @@ from zarr.core.common import ZARR_JSON, parse_named_configuration, parse_shapelike from zarr.core.config import config from zarr.core.metadata.common import ArrayMetadata, parse_attributes +from zarr.core.strings import STRING_DTYPE as STRING_NP_DTYPE from zarr.registry import get_codec_class -from zarr.strings import STRING_DTYPE as STRING_NP_DTYPE DEFAULT_DTYPE = "float64" diff --git a/src/zarr/strings.py b/src/zarr/strings.py deleted file mode 100644 index 06d2489442..0000000000 --- a/src/zarr/strings.py +++ /dev/null @@ -1,87 +0,0 @@ -"""This module contains utilities for working with string arrays across -different versions of Numpy. -""" - -from typing import Any, Union, cast -from warnings import warn - -import numpy as np - -# STRING_DTYPE is the in-memory datatype that will be used for V3 string arrays -# when reading data back from Zarr. -# Any valid string-like datatype should be fine for *setting* data. - -STRING_DTYPE: Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"] -NUMPY_SUPPORTS_VLEN_STRING: bool - - -def cast_array( - data: np.ndarray[Any, np.dtype[Any]], -) -> np.ndarray[Any, Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"]]: - raise NotImplementedError - - -try: - # this new vlen string dtype was added in NumPy 2.0 - STRING_DTYPE = np.dtypes.StringDType() - NUMPY_SUPPORTS_VLEN_STRING = True - - def cast_array( - data: np.ndarray[Any, np.dtype[Any]], - ) -> np.ndarray[Any, np.dtypes.StringDType | np.dtypes.ObjectDType]: - out = data.astype(STRING_DTYPE, copy=False) - return cast(np.ndarray[Any, np.dtypes.StringDType], out) - -except AttributeError: - # if not available, we fall back on an object array of strings, as in Zarr < 3 - STRING_DTYPE = np.dtypes.ObjectDType() - NUMPY_SUPPORTS_VLEN_STRING = False - - def cast_array( - data: np.ndarray[Any, np.dtype[Any]], - ) -> np.ndarray[Any, Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"]]: - out = data.astype(STRING_DTYPE, copy=False) - return cast(np.ndarray[Any, np.dtypes.ObjectDType], out) - - -def cast_to_string_dtype( - data: np.ndarray[Any, np.dtype[Any]], safe: bool = False -) -> np.ndarray[Any, Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"]]: - """Take any data and attempt to cast to to our preferred string dtype. - - data : np.ndarray - The data to cast - - safe : bool - If True, do not issue a warning if the data is cast from object to string dtype. 
- - """ - if np.issubdtype(data.dtype, np.str_): - # legacy fixed-width string type (e.g. "= 2.", - stacklevel=2, - ) - return cast_array(data) - raise ValueError(f"Cannot cast dtype {data.dtype} to string dtype") diff --git a/tests/test_strings.py b/tests/test_strings.py index 88cde18a3e..709b54e36d 100644 --- a/tests/test_strings.py +++ b/tests/test_strings.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from zarr.strings import NUMPY_SUPPORTS_VLEN_STRING, STRING_DTYPE, cast_to_string_dtype +from zarr.core.strings import NUMPY_SUPPORTS_VLEN_STRING, STRING_DTYPE, cast_to_string_dtype def test_string_defaults() -> None: diff --git a/tests/v3/test_array.py b/tests/v3/test_array.py index 291d284483..04adb2a224 100644 --- a/tests/v3/test_array.py +++ b/tests/v3/test_array.py @@ -8,7 +8,7 @@ import zarr.api.asynchronous import zarr.storage from zarr import Array, AsyncArray, Group -from zarr.codecs.bytes import BytesCodec +from zarr.codecs import BytesCodec, VLenBytesCodec from zarr.core.array import chunks_initialized from zarr.core.buffer.cpu import NDBuffer from zarr.core.common import JSON, ZarrFormat @@ -370,3 +370,39 @@ def test_chunks_initialized(test_cls: type[Array] | type[AsyncArray]) -> None: expected = sorted(keys) assert observed == expected + + +def test_default_fill_values() -> None: + a = Array.create(MemoryStore({}, mode="w"), shape=5, chunk_shape=5, dtype=" None: + with pytest.raises(ValueError, match="At least one ArrayBytesCodec is required."): + Array.create(MemoryStore({}, mode="w"), shape=5, chunk_shape=5, dtype=" None: - a = Array.create(StorePath(store, path="string"), shape=5, chunk_shape=5, dtype=" None: - with pytest.raises(ValueError, match="At least one ArrayBytesCodec is required."): - Array.create(StorePath(store, path="a"), shape=5, chunk_shape=5, dtype=" Date: Tue, 8 Oct 2024 14:39:37 -0400 Subject: [PATCH 28/29] add missing module --- src/zarr/core/strings.py | 87 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 src/zarr/core/strings.py diff --git a/src/zarr/core/strings.py b/src/zarr/core/strings.py new file mode 100644 index 0000000000..06d2489442 --- /dev/null +++ b/src/zarr/core/strings.py @@ -0,0 +1,87 @@ +"""This module contains utilities for working with string arrays across +different versions of Numpy. +""" + +from typing import Any, Union, cast +from warnings import warn + +import numpy as np + +# STRING_DTYPE is the in-memory datatype that will be used for V3 string arrays +# when reading data back from Zarr. +# Any valid string-like datatype should be fine for *setting* data. 
+
+STRING_DTYPE: Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"]
+NUMPY_SUPPORTS_VLEN_STRING: bool
+
+
+def cast_array(
+    data: np.ndarray[Any, np.dtype[Any]],
+) -> np.ndarray[Any, Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"]]:
+    raise NotImplementedError
+
+
+try:
+    # this new vlen string dtype was added in NumPy 2.0
+    STRING_DTYPE = np.dtypes.StringDType()
+    NUMPY_SUPPORTS_VLEN_STRING = True
+
+    def cast_array(
+        data: np.ndarray[Any, np.dtype[Any]],
+    ) -> np.ndarray[Any, np.dtypes.StringDType | np.dtypes.ObjectDType]:
+        out = data.astype(STRING_DTYPE, copy=False)
+        return cast(np.ndarray[Any, np.dtypes.StringDType], out)
+
+except AttributeError:
+    # if not available, we fall back on an object array of strings, as in Zarr < 3
+    STRING_DTYPE = np.dtypes.ObjectDType()
+    NUMPY_SUPPORTS_VLEN_STRING = False
+
+    def cast_array(
+        data: np.ndarray[Any, np.dtype[Any]],
+    ) -> np.ndarray[Any, Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"]]:
+        out = data.astype(STRING_DTYPE, copy=False)
+        return cast(np.ndarray[Any, np.dtypes.ObjectDType], out)
+
+
+def cast_to_string_dtype(
+    data: np.ndarray[Any, np.dtype[Any]], safe: bool = False
+) -> np.ndarray[Any, Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"]]:
+    """Take any data and attempt to cast to our preferred string dtype.
+
+    data : np.ndarray
+        The data to cast
+
+    safe : bool
+        If True, do not issue a warning if the data is cast from object to string dtype.
+
+    """
+    if np.issubdtype(data.dtype, np.str_):
+        # legacy fixed-width string type (e.g. "= 2.",
+            stacklevel=2,
+        )
+        return cast_array(data)
+    raise ValueError(f"Cannot cast dtype {data.dtype} to string dtype")

From ba0f0936a47c4a6e25e7b5b3b97b50c2b0a7d61b Mon Sep 17 00:00:00 2001
From: Ryan Abernathey
Date: Tue, 8 Oct 2024 15:00:09 -0400
Subject: [PATCH 29/29] make _STRING_DTYPE private to try to make sphinx happy

---
 src/zarr/core/metadata/v3.py      |  2 +-
 src/zarr/core/strings.py          | 24 ++++++++++++------------
 tests/test_strings.py             | 22 +++++++++++-----------
 tests/v3/test_codecs/test_vlen.py |  4 ++--
 4 files changed, 26 insertions(+), 26 deletions(-)

diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py
index 006b6cf84b..47c6106bfe 100644
--- a/src/zarr/core/metadata/v3.py
+++ b/src/zarr/core/metadata/v3.py
@@ -28,7 +28,7 @@
 from zarr.core.common import ZARR_JSON, parse_named_configuration, parse_shapelike
 from zarr.core.config import config
 from zarr.core.metadata.common import ArrayMetadata, parse_attributes
-from zarr.core.strings import STRING_DTYPE as STRING_NP_DTYPE
+from zarr.core.strings import _STRING_DTYPE as STRING_NP_DTYPE
 from zarr.registry import get_codec_class

 DEFAULT_DTYPE = "float64"
diff --git a/src/zarr/core/strings.py b/src/zarr/core/strings.py
index 06d2489442..9ec391c04a 100644
--- a/src/zarr/core/strings.py
+++ b/src/zarr/core/strings.py
@@ -7,12 +7,12 @@

 import numpy as np

-# STRING_DTYPE is the in-memory datatype that will be used for V3 string arrays
+# _STRING_DTYPE is the in-memory datatype that will be used for V3 string arrays
 # when reading data back from Zarr.
 # Any valid string-like datatype should be fine for *setting* data.

-STRING_DTYPE: Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"] -NUMPY_SUPPORTS_VLEN_STRING: bool +_STRING_DTYPE: Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"] +_NUMPY_SUPPORTS_VLEN_STRING: bool def cast_array( @@ -23,24 +23,24 @@ def cast_array( try: # this new vlen string dtype was added in NumPy 2.0 - STRING_DTYPE = np.dtypes.StringDType() - NUMPY_SUPPORTS_VLEN_STRING = True + _STRING_DTYPE = np.dtypes.StringDType() + _NUMPY_SUPPORTS_VLEN_STRING = True def cast_array( data: np.ndarray[Any, np.dtype[Any]], ) -> np.ndarray[Any, np.dtypes.StringDType | np.dtypes.ObjectDType]: - out = data.astype(STRING_DTYPE, copy=False) + out = data.astype(_STRING_DTYPE, copy=False) return cast(np.ndarray[Any, np.dtypes.StringDType], out) except AttributeError: # if not available, we fall back on an object array of strings, as in Zarr < 3 - STRING_DTYPE = np.dtypes.ObjectDType() - NUMPY_SUPPORTS_VLEN_STRING = False + _STRING_DTYPE = np.dtypes.ObjectDType() + _NUMPY_SUPPORTS_VLEN_STRING = False def cast_array( data: np.ndarray[Any, np.dtype[Any]], ) -> np.ndarray[Any, Union["np.dtypes.StringDType", "np.dtypes.ObjectDType"]]: - out = data.astype(STRING_DTYPE, copy=False) + out = data.astype(_STRING_DTYPE, copy=False) return cast(np.ndarray[Any, np.dtypes.ObjectDType], out) @@ -61,13 +61,13 @@ def cast_to_string_dtype( return cast_array(data) # out = data.astype(STRING_DTYPE, copy=False) # return cast(np.ndarray[Any, np.dtypes.StringDType | np.dtypes.ObjectDType], out) - if NUMPY_SUPPORTS_VLEN_STRING: - if np.issubdtype(data.dtype, STRING_DTYPE): + if _NUMPY_SUPPORTS_VLEN_STRING: + if np.issubdtype(data.dtype, _STRING_DTYPE): # already a valid string variable length string dtype return cast_array(data) if np.issubdtype(data.dtype, np.object_): # object arrays require more careful handling - if NUMPY_SUPPORTS_VLEN_STRING: + if _NUMPY_SUPPORTS_VLEN_STRING: try: # cast to variable-length string dtype, fail if object contains non-string data # mypy says "error: Unexpected keyword argument "coerce" for "StringDType" [call-arg]" diff --git a/tests/test_strings.py b/tests/test_strings.py index 709b54e36d..dca0570a25 100644 --- a/tests/test_strings.py +++ b/tests/test_strings.py @@ -3,33 +3,33 @@ import numpy as np import pytest -from zarr.core.strings import NUMPY_SUPPORTS_VLEN_STRING, STRING_DTYPE, cast_to_string_dtype +from zarr.core.strings import _NUMPY_SUPPORTS_VLEN_STRING, _STRING_DTYPE, cast_to_string_dtype def test_string_defaults() -> None: - if NUMPY_SUPPORTS_VLEN_STRING: - assert STRING_DTYPE == np.dtypes.StringDType() + if _NUMPY_SUPPORTS_VLEN_STRING: + assert _STRING_DTYPE == np.dtypes.StringDType() else: - assert STRING_DTYPE == np.dtypes.ObjectDType() + assert _STRING_DTYPE == np.dtypes.ObjectDType() def test_cast_to_string_dtype() -> None: d1 = np.array(["a", "b", "c"]) assert d1.dtype == np.dtype("