Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change ArrayV3Metadata.data_type to DataType #2278

Merged
merged 2 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 37 additions & 25 deletions src/zarr/core/metadata/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
if TYPE_CHECKING:
from typing import Self

import numpy.typing as npt

from zarr.core.buffer import Buffer, BufferPrototype
from zarr.core.chunk_grids import ChunkGrid
from zarr.core.common import JSON, ChunkCoords
Expand All @@ -20,6 +18,7 @@

import numcodecs.abc
import numpy as np
import numpy.typing as npt

from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec, Codec
from zarr.core.array_spec import ArraySpec
Expand All @@ -31,6 +30,8 @@
from zarr.core.metadata.common import ArrayMetadata, parse_attributes
from zarr.registry import get_codec_class

DEFAULT_DTYPE = "float64"


def parse_zarr_format(data: object) -> Literal[3]:
if data == 3:
Expand Down Expand Up @@ -152,7 +153,7 @@ def _replace_special_floats(obj: object) -> Any:
@dataclass(frozen=True, kw_only=True)
class ArrayV3Metadata(ArrayMetadata):
shape: ChunkCoords
data_type: np.dtype[Any]
data_type: DataType
chunk_grid: ChunkGrid
chunk_key_encoding: ChunkKeyEncoding
fill_value: Any
Expand All @@ -167,7 +168,7 @@ def __init__(
self,
*,
shape: Iterable[int],
data_type: npt.DTypeLike,
data_type: npt.DTypeLike | DataType,
chunk_grid: dict[str, JSON] | ChunkGrid,
chunk_key_encoding: dict[str, JSON] | ChunkKeyEncoding,
fill_value: Any,
Expand All @@ -180,18 +181,18 @@ def __init__(
Because the class is a frozen dataclass, we set attributes using object.__setattr__
"""
shape_parsed = parse_shapelike(shape)
data_type_parsed = parse_dtype(data_type)
data_type_parsed = DataType.parse(data_type)
chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid)
chunk_key_encoding_parsed = ChunkKeyEncoding.from_dict(chunk_key_encoding)
dimension_names_parsed = parse_dimension_names(dimension_names)
fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed)
fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed.to_numpy())
attributes_parsed = parse_attributes(attributes)
codecs_parsed_partial = parse_codecs(codecs)
storage_transformers_parsed = parse_storage_transformers(storage_transformers)

array_spec = ArraySpec(
shape=shape_parsed,
dtype=data_type_parsed,
dtype=data_type_parsed.to_numpy(),
fill_value=fill_value_parsed,
order="C", # TODO: order is not needed here.
prototype=default_buffer_prototype(), # TODO: prototype is not needed here.
Expand Down Expand Up @@ -224,11 +225,14 @@ def _validate_metadata(self) -> None:
if self.fill_value is None:
raise ValueError("`fill_value` is required.")
for codec in self.codecs:
codec.validate(shape=self.shape, dtype=self.data_type, chunk_grid=self.chunk_grid)
codec.validate(
shape=self.shape, dtype=self.data_type.to_numpy(), chunk_grid=self.chunk_grid
)

@property
def dtype(self) -> np.dtype[Any]:
return self.data_type
"""Interpret Zarr dtype as NumPy dtype"""
return self.data_type.to_numpy()

@property
def ndim(self) -> int:
Expand Down Expand Up @@ -266,13 +270,13 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
_ = parse_node_type_array(_data.pop("node_type"))

# check that the data_type attribute is valid
_ = DataType(_data["data_type"])
data_type = DataType.parse(_data.pop("data_type"))

# dimension_names key is optional, normalize missing to `None`
_data["dimension_names"] = _data.pop("dimension_names", None)
# attributes key is optional, normalize missing to `None`
_data["attributes"] = _data.pop("attributes", None)
return cls(**_data) # type: ignore[arg-type]
return cls(**_data, data_type=data_type) # type: ignore[arg-type]

def to_dict(self) -> dict[str, JSON]:
out_dict = super().to_dict()
Expand Down Expand Up @@ -490,8 +494,11 @@ def to_numpy_shortname(self) -> str:
}
return data_type_to_numpy[self]

def to_numpy(self) -> np.dtype[Any]:
    """Return the NumPy dtype equivalent of this Zarr V3 data type."""
    shortname = self.to_numpy_shortname()
    return np.dtype(shortname)

@classmethod
def from_dtype(cls, dtype: np.dtype[Any]) -> DataType:
def from_numpy(cls, dtype: np.dtype[Any]) -> DataType:
dtype_to_data_type = {
"|b1": "bool",
"bool": "bool",
Expand All @@ -511,16 +518,21 @@ def from_dtype(cls, dtype: np.dtype[Any]) -> DataType:
}
return DataType[dtype_to_data_type[dtype.str]]


def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
try:
dtype = np.dtype(data)
except (ValueError, TypeError) as e:
raise ValueError(f"Invalid V3 data_type: {data}") from e
# check that this is a valid v3 data_type
try:
_ = DataType.from_dtype(dtype)
except KeyError as e:
raise ValueError(f"Invalid V3 data_type: {dtype}") from e

return dtype
@classmethod
def parse(cls, dtype: None | DataType | Any) -> DataType:
    """Coerce ``dtype`` into a ``DataType`` member.

    ``None`` maps to the default dtype; an existing ``DataType`` passes
    through unchanged; anything else is interpreted as a NumPy dtype-like
    value and must correspond to a valid Zarr V3 data type.

    Raises
    ------
    ValueError
        If ``dtype`` is neither NumPy-interpretable nor a valid V3 data type.
    """
    # Guard clauses for the two trivial cases.
    if dtype is None:
        return DataType[DEFAULT_DTYPE]
    if isinstance(dtype, DataType):
        return dtype
    # Otherwise treat the input as a NumPy dtype-like object.
    try:
        np_dtype = np.dtype(dtype)
    except (ValueError, TypeError) as e:
        raise ValueError(f"Invalid V3 data_type: {dtype}") from e
    # A NumPy dtype is only acceptable if it maps onto a V3 data type.
    try:
        return DataType.from_numpy(np_dtype)
    except KeyError as e:
        raise ValueError(f"Invalid V3 data_type: {np_dtype}") from e
11 changes: 5 additions & 6 deletions tests/v3/test_metadata/test_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from zarr.codecs.bytes import BytesCodec
from zarr.core.buffer import default_buffer_prototype
from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding
from zarr.core.metadata.v3 import ArrayV3Metadata
from zarr.core.metadata.v3 import ArrayV3Metadata, DataType

if TYPE_CHECKING:
from collections.abc import Sequence
Expand All @@ -22,7 +22,6 @@

from zarr.core.metadata.v3 import (
parse_dimension_names,
parse_dtype,
parse_fill_value,
parse_zarr_format,
)
Expand Down Expand Up @@ -209,7 +208,7 @@ def test_metadata_to_dict(
storage_transformers: None | tuple[dict[str, JSON]],
) -> None:
shape = (1, 2, 3)
data_type = "uint8"
data_type = DataType.uint8
if chunk_grid == "regular":
cgrid = {"name": "regular", "configuration": {"chunk_shape": (1, 1, 1)}}

Expand Down Expand Up @@ -290,7 +289,7 @@ def test_metadata_to_dict(
# assert result["fill_value"] == fill_value


async def test_invalid_dtype_raises() -> None:
def test_invalid_dtype_raises() -> None:
metadata_dict = {
"zarr_format": 3,
"node_type": "array",
Expand All @@ -301,14 +300,14 @@ async def test_invalid_dtype_raises() -> None:
"codecs": (),
"fill_value": np.datetime64(0, "ns"),
}
with pytest.raises(ValueError, match=r".* is not a valid DataType"):
with pytest.raises(ValueError, match=r"Invalid V3 data_type: .*"):
ArrayV3Metadata.from_dict(metadata_dict)


@pytest.mark.parametrize("data", ["datetime64[s]", "foo", object()])
def test_parse_invalid_dtype_raises(data: object) -> None:
    """DataType.parse must reject inputs that are not valid Zarr V3 data types.

    Covers a NumPy dtype outside the V3 spec ("datetime64[s]"), an unknown
    string, and a non-dtype object; each must raise ValueError.
    """
    with pytest.raises(ValueError, match=r"Invalid V3 data_type: .*"):
        DataType.parse(data)


@pytest.mark.parametrize(
Expand Down