Skip to content

Commit

Permalink
Special case object dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAugspurger committed Oct 9, 2024
1 parent aa46b45 commit 483681b
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 8 deletions.
4 changes: 3 additions & 1 deletion src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
ShapeLike,
ZarrFormat,
concurrent_map,
parse_dtype,
parse_shapelike,
product,
)
Expand Down Expand Up @@ -226,7 +227,8 @@ async def create(
if chunks is not None and chunk_shape is not None:
raise ValueError("Only one of chunk_shape or chunks can be provided.")

dtype = np.dtype(dtype)
dtype = parse_dtype(dtype)
# dtype = np.dtype(dtype)
if chunks:
_chunks = normalize_chunks(chunks, shape, dtype.itemsize)
else:
Expand Down
9 changes: 9 additions & 0 deletions src/zarr/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
overload,
)

import numpy as np

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Iterator

Expand Down Expand Up @@ -162,3 +164,10 @@ def parse_order(data: Any) -> Literal["C", "F"]:
if data in ("C", "F"):
return cast(Literal["C", "F"], data)
raise ValueError(f"Expected one of ('C', 'F'), got {data} instead.")


def parse_dtype(dtype: Any) -> np.dtype[Any]:
    """Convert a dtype-like value into a NumPy dtype.

    The builtin ``str`` type and the string ``"str"`` are mapped to the
    NumPy ``object`` dtype. Any other value is passed to ``np.dtype``
    as-is.

    Parameters
    ----------
    dtype
        Anything ``np.dtype`` accepts, or ``str`` / ``"str"``.

    Returns
    -------
    np.dtype
        The resolved NumPy dtype.
    """
    wants_str = dtype is str or dtype == "str"
    # String requests are special-cased to the object dtype.
    resolved = np.dtype("object") if wants_str else np.dtype(dtype)
    return resolved
7 changes: 1 addition & 6 deletions src/zarr/core/metadata/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from zarr.core.array_spec import ArraySpec
from zarr.core.chunk_grids import RegularChunkGrid
from zarr.core.chunk_key_encodings import parse_separator
from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_shapelike
from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_dtype, parse_shapelike
from zarr.core.config import config, parse_indexing_order
from zarr.core.metadata.common import ArrayMetadata, parse_attributes

Expand Down Expand Up @@ -201,11 +201,6 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self:
return replace(self, attributes=attributes)


def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
    """Return the NumPy dtype built from ``data`` via ``np.dtype``."""
    # todo: real validation
    parsed = np.dtype(data)
    return parsed


def parse_zarr_format(data: object) -> Literal[2]:
if data == 2:
return 2
Expand Down
4 changes: 4 additions & 0 deletions src/zarr/core/metadata/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ class DataType(Enum):
complex128 = "complex128"
string = "string"
bytes = "bytes"
object = "object"

@property
def byte_count(self) -> None | int:
Expand Down Expand Up @@ -549,6 +550,7 @@ def to_numpy_shortname(self) -> str:
DataType.float64: "f8",
DataType.complex64: "c8",
DataType.complex128: "c16",
DataType.object: "object",
}
return data_type_to_numpy[self]

Expand All @@ -572,6 +574,8 @@ def from_numpy(cls, dtype: np.dtype[Any]) -> DataType:
return DataType.string
elif dtype.kind == "S":
return DataType.bytes
elif dtype.kind == "O":
return DataType.object
dtype_to_data_type = {
"|b1": "bool",
"bool": "bool",
Expand Down
9 changes: 8 additions & 1 deletion tests/v3/test_array.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pickle
from itertools import accumulate
from typing import Literal
from typing import Any, Literal

import numpy as np
import pytest
Expand Down Expand Up @@ -406,3 +406,10 @@ def test_vlen_errors() -> None:
dtype="<U4",
codecs=[BytesCodec(), VLenBytesCodec()],
)


@pytest.mark.parametrize("zarr_format", [2, 3, None])
@pytest.mark.parametrize("dtype", [str, "str"])
def test_create_dtype_str(dtype: Any, zarr_format: ZarrFormat | None) -> None:
    """Creating an array with ``str`` or ``"str"`` as dtype gives an object-kind dtype."""
    created = zarr.create(shape=10, dtype=dtype, zarr_format=zarr_format)
    kind = created.dtype.kind
    assert kind == "O"

0 comments on commit 483681b

Please sign in to comment.