diff --git a/src/zarr/core/metadata/v3.py b/src/zarr/core/metadata/v3.py index f0c6dc628..692f77856 100644 --- a/src/zarr/core/metadata/v3.py +++ b/src/zarr/core/metadata/v3.py @@ -6,8 +6,6 @@ if TYPE_CHECKING: from typing import Self - import numpy.typing as npt - from zarr.core.buffer import Buffer, BufferPrototype from zarr.core.chunk_grids import ChunkGrid from zarr.core.common import JSON, ChunkCoords @@ -20,6 +18,7 @@ import numcodecs.abc import numpy as np +import numpy.typing as npt from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec from zarr.core.array_spec import ArraySpec @@ -31,6 +30,8 @@ from zarr.core.metadata.common import ArrayMetadata, parse_attributes from zarr.registry import get_codec_class +DEFAULT_DTYPE = "float64" + def parse_zarr_format(data: object) -> Literal[3]: if data == 3: @@ -152,7 +153,7 @@ def _replace_special_floats(obj: object) -> Any: @dataclass(frozen=True, kw_only=True) class ArrayV3Metadata(ArrayMetadata): shape: ChunkCoords - data_type: np.dtype[Any] + data_type: DataType chunk_grid: ChunkGrid chunk_key_encoding: ChunkKeyEncoding fill_value: Any @@ -167,7 +168,7 @@ def __init__( self, *, shape: Iterable[int], - data_type: npt.DTypeLike, + data_type: npt.DTypeLike | DataType, chunk_grid: dict[str, JSON] | ChunkGrid, chunk_key_encoding: dict[str, JSON] | ChunkKeyEncoding, fill_value: Any, @@ -180,18 +181,18 @@ def __init__( Because the class is a frozen dataclass, we set attributes using object.__setattr__ """ shape_parsed = parse_shapelike(shape) - data_type_parsed = parse_dtype(data_type) + data_type_parsed = DataType.parse(data_type) chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid) chunk_key_encoding_parsed = ChunkKeyEncoding.from_dict(chunk_key_encoding) dimension_names_parsed = parse_dimension_names(dimension_names) - fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed) + fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed.to_numpy()) attributes_parsed = parse_attributes(attributes) codecs_parsed_partial = parse_codecs(codecs) storage_transformers_parsed = parse_storage_transformers(storage_transformers) array_spec = ArraySpec( shape=shape_parsed, - dtype=data_type_parsed, + dtype=data_type_parsed.to_numpy(), fill_value=fill_value_parsed, order="C", # TODO: order is not needed here. prototype=default_buffer_prototype(), # TODO: prototype is not needed here. @@ -224,11 +225,14 @@ def _validate_metadata(self) -> None: if self.fill_value is None: raise ValueError("`fill_value` is required.") for codec in self.codecs: - codec.validate(shape=self.shape, dtype=self.data_type, chunk_grid=self.chunk_grid) + codec.validate( + shape=self.shape, dtype=self.data_type.to_numpy(), chunk_grid=self.chunk_grid + ) @property def dtype(self) -> np.dtype[Any]: - return self.data_type + """Interpret Zarr dtype as NumPy dtype""" + return self.data_type.to_numpy() @property def ndim(self) -> int: @@ -266,13 +270,13 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: _ = parse_node_type_array(_data.pop("node_type")) # check that the data_type attribute is valid - _ = DataType(_data["data_type"]) + data_type = DataType.parse(_data.pop("data_type")) # dimension_names key is optional, normalize missing to `None` _data["dimension_names"] = _data.pop("dimension_names", None) # attributes key is optional, normalize missing to `None` _data["attributes"] = _data.pop("attributes", None) - return cls(**_data) # type: ignore[arg-type] + return cls(**_data, data_type=data_type) # type: ignore[arg-type] def to_dict(self) -> dict[str, JSON]: out_dict = super().to_dict() @@ -490,8 +494,11 @@ def to_numpy_shortname(self) -> str: } return data_type_to_numpy[self] + def to_numpy(self) -> np.dtype[Any]: + return np.dtype(self.to_numpy_shortname()) + @classmethod - def from_dtype(cls, dtype: np.dtype[Any]) -> DataType: + def from_numpy(cls, dtype: np.dtype[Any]) -> DataType: dtype_to_data_type = { "|b1": "bool", "bool": "bool", @@ -511,16 +518,21 @@ def from_dtype(cls, dtype: np.dtype[Any]) -> DataType: } return DataType[dtype_to_data_type[dtype.str]] - -def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]: - try: - dtype = np.dtype(data) - except (ValueError, TypeError) as e: - raise ValueError(f"Invalid V3 data_type: {data}") from e - # check that this is a valid v3 data_type - try: - _ = DataType.from_dtype(dtype) - except KeyError as e: - raise ValueError(f"Invalid V3 data_type: {dtype}") from e - - return dtype + @classmethod + def parse(cls, dtype: None | DataType | Any) -> DataType: + if dtype is None: + # the default dtype + return DataType[DEFAULT_DTYPE] + if isinstance(dtype, DataType): + return dtype + else: + try: + dtype = np.dtype(dtype) + except (ValueError, TypeError) as e: + raise ValueError(f"Invalid V3 data_type: {dtype}") from e + # check that this is a valid v3 data_type + try: + data_type = DataType.from_numpy(dtype) + except KeyError as e: + raise ValueError(f"Invalid V3 data_type: {dtype}") from e + return data_type diff --git a/tests/v3/test_metadata/test_v3.py b/tests/v3/test_metadata/test_v3.py index 71dc917c3..534ef61d0 100644 --- a/tests/v3/test_metadata/test_v3.py +++ b/tests/v3/test_metadata/test_v3.py @@ -7,7 +7,7 @@ from zarr.codecs.bytes import BytesCodec from zarr.core.buffer import default_buffer_prototype from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding -from zarr.core.metadata.v3 import ArrayV3Metadata +from zarr.core.metadata.v3 import ArrayV3Metadata, DataType if TYPE_CHECKING: from collections.abc import Sequence @@ -22,7 +22,6 @@ from zarr.core.metadata.v3 import ( parse_dimension_names, - parse_dtype, parse_fill_value, parse_zarr_format, ) @@ -209,7 +208,7 @@ def test_metadata_to_dict( storage_transformers: None | tuple[dict[str, JSON]], ) -> None: shape = (1, 2, 3) - data_type = "uint8" + data_type = DataType.uint8 if chunk_grid == "regular": cgrid = {"name": "regular", "configuration": {"chunk_shape": (1, 1, 1)}} @@ -290,7 +289,7 @@ def test_metadata_to_dict( # assert result["fill_value"] == fill_value -async def test_invalid_dtype_raises() -> None: +def test_invalid_dtype_raises() -> None: metadata_dict = { "zarr_format": 3, "node_type": "array", @@ -301,14 +300,14 @@ async def test_invalid_dtype_raises() -> None: "codecs": (), "fill_value": np.datetime64(0, "ns"), } - with pytest.raises(ValueError, match=r".* is not a valid DataType"): + with pytest.raises(ValueError, match=r"Invalid V3 data_type: .*"): ArrayV3Metadata.from_dict(metadata_dict) @pytest.mark.parametrize("data", ["datetime64[s]", "foo", object()]) def test_parse_invalid_dtype_raises(data): with pytest.raises(ValueError, match=r"Invalid V3 data_type: .*"): - parse_dtype(data) + DataType.parse(data) @pytest.mark.parametrize(