Skip to content

Commit

Permalink
feature(store): make list_* methods async generators
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamman committed Apr 9, 2024
1 parent 624ff77 commit 7dff5e5
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 62 deletions.
13 changes: 7 additions & 6 deletions src/zarr/v3/abc/store.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import abstractmethod, ABC
from collections.abc import AsyncGenerator

from typing import List, Tuple, Optional

Expand Down Expand Up @@ -106,17 +107,17 @@ def supports_listing(self) -> bool:
...

@abstractmethod
async def list(self) -> List[str]:
async def list(self) -> AsyncGenerator[str, None]:
"""Retrieve all keys in the store.
Returns
-------
list[str]
AsyncGenerator[str, None]
"""
...

@abstractmethod
async def list_prefix(self, prefix: str) -> List[str]:
async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
"""Retrieve all keys in the store with a given prefix.
Parameters
Expand All @@ -125,12 +126,12 @@ async def list_prefix(self, prefix: str) -> List[str]:
Returns
-------
list[str]
AsyncGenerator[str, None]
"""
...

@abstractmethod
async def list_dir(self, prefix: str) -> List[str]:
async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
"""
Retrieve all keys and prefixes with a given prefix and which do not contain the character
“/” after the given prefix.
Expand All @@ -141,6 +142,6 @@ async def list_dir(self, prefix: str) -> List[str]:
Returns
-------
list[str]
AsyncGenerator[str, None]
"""
...
32 changes: 18 additions & 14 deletions src/zarr/v3/group.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING
from dataclasses import asdict, dataclass, field, replace

Expand All @@ -9,7 +10,6 @@
if TYPE_CHECKING:
from typing import (
Any,
AsyncGenerator,
Literal,
AsyncIterator,
Iterator,
Expand Down Expand Up @@ -161,6 +161,7 @@ async def getitem(
) -> AsyncArray | AsyncGroup:

store_path = self.store_path / key
logger.warning("key=%s, store_path=%s", key, store_path)

# Note:
# in zarr-python v2, we first check if `key` references an Array, else if `key` references
Expand Down Expand Up @@ -289,7 +290,7 @@ def __repr__(self):
async def nchildren(self) -> int:
raise NotImplementedError

async def children(self) -> AsyncGenerator[AsyncArray, AsyncGroup]:
async def children(self) -> AsyncGenerator[AsyncArray | AsyncGroup, None]:
"""
Returns an AsyncGenerator over the arrays and groups contained in this group.
This method requires that `store_path.store` supports directory listing.
Expand All @@ -303,18 +304,21 @@ async def children(self) -> AsyncGenerator[AsyncArray, AsyncGroup]:
)

raise ValueError(msg)
subkeys = await self.store_path.store.list_dir(self.store_path.path)
# would be nice to make these special keys accessible programmatically,
# and scoped to specific zarr versions
subkeys_filtered = filter(lambda v: v not in ("zarr.json", ".zgroup", ".zattrs"), subkeys)
# is there a better way to schedule this?
for subkey in subkeys_filtered:
try:
yield await self.getitem(subkey)
except KeyError:
# keyerror is raised when `subkey``names an object in the store
# in which case `subkey` cannot be the name of a sub-array or sub-group.
pass

async for key in self.store_path.store.list_dir(self.store_path.path):
# these keys are not valid child names so we make sure to skip them
# TODO: it would be nice to make these special keys accessible programmatically,
# and scoped to specific zarr versions
if key not in ("zarr.json", ".zgroup", ".zattrs"):
try:
# TODO: performance optimization -- batch
print(key)
child = await self.getitem(key)
# keyerror is raised when `subkey``names an object in the store
# in which case `subkey` cannot be the name of a sub-array or sub-group.
yield child
except KeyError:
pass

async def contains(self, child: str) -> bool:
raise NotImplementedError
Expand Down
47 changes: 22 additions & 25 deletions src/zarr/v3/store/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import io
import shutil
from collections.abc import AsyncGenerator
from pathlib import Path
from typing import Union, Optional, List, Tuple

Expand Down Expand Up @@ -142,21 +143,19 @@ async def exists(self, key: str) -> bool:
path = self.root / key
return await to_thread(path.is_file)

async def list(self) -> List[str]:
async def list(self) -> AsyncGenerator[str, None]:
"""Retrieve all keys in the store.
Returns
-------
list[str]
AsyncGenerator[str, None]
"""
# Q: do we want to return strings or Paths?
def _list(root: Path) -> List[str]:
files = [str(p) for p in root.rglob("") if p.is_file()]
return files
for p in self.root.rglob(""):
if p.is_file():
yield str(p)

return await to_thread(_list, self.root)

async def list_prefix(self, prefix: str) -> List[str]:
async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
"""Retrieve all keys in the store with a given prefix.
Parameters
Expand All @@ -165,16 +164,14 @@ async def list_prefix(self, prefix: str) -> List[str]:
Returns
-------
list[str]
AsyncGenerator[str, None]
"""
for p in (self.root / prefix).rglob("*"):
if p.is_file():
yield str(p)

def _list_prefix(root: Path, prefix: str) -> List[str]:
files = [p for p in (root / prefix).rglob("*") if p.is_file()]
return files

return await to_thread(_list_prefix, self.root, prefix)

async def list_dir(self, prefix: str) -> List[str]:
async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
"""
Retrieve all keys and prefixes with a given prefix and which do not contain the character
“/” after the given prefix.
Expand All @@ -185,16 +182,16 @@ async def list_dir(self, prefix: str) -> List[str]:
Returns
-------
list[str]
AsyncGenerator[str, None]
"""
base = self.root / prefix
to_strip = str(base) + "/"

try:
key_iter = base.iterdir()
except (FileNotFoundError, NotADirectoryError):
key_iter = []

def _list_dir(root: Path, prefix: str) -> List[str]:

base = root / prefix
to_strip = str(base) + "/"
try:
return [str(key).replace(to_strip, "") for key in base.iterdir()]
except (FileNotFoundError, NotADirectoryError):
return []
for key in key_iter:
yield str(key).replace(to_strip, "")

return await to_thread(_list_dir, self.root, prefix)
33 changes: 16 additions & 17 deletions src/zarr/v3/store/memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections.abc import AsyncGenerator
from typing import Optional, MutableMapping, List, Tuple

from zarr.v3.common import BytesLike
Expand Down Expand Up @@ -67,20 +68,18 @@ async def delete(self, key: str) -> None:
async def set_partial_values(self, key_start_values: List[Tuple[str, int, bytes]]) -> None:
raise NotImplementedError

async def list(self) -> List[str]:
return list(self._store_dict.keys())

async def list_prefix(self, prefix: str) -> List[str]:
return [key for key in self._store_dict if key.startswith(prefix)]

async def list_dir(self, prefix: str) -> List[str]:
if prefix == "":
return list({key.split("/", maxsplit=1)[0] for key in self._store_dict})
else:
return list(
{
key.strip(prefix + "/").split("/")[0]
for key in self._store_dict
if (key.startswith(prefix + "/") and key != prefix)
}
)
async def list(self) -> AsyncGenerator[str, None]:
for key in self._store_dict:
yield key

async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
for key in self._store_dict:
if key.startswith(prefix):
yield key

async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
print('prefix', prefix)
print('keys in list_dir', list(self._store_dict))
for key in self._store_dict:
if key.startswith(prefix + "/") and key != prefix:
yield key.strip(prefix + "/").rsplit("/", maxsplit=1)[0]
5 changes: 5 additions & 0 deletions tests/v3/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,18 @@ def test_group_children(store: MemoryStore | LocalStore):
# if group.children guarantees a particular order for the children.
# If order is not guaranteed, then the better version of this test is
# to compare two sets, but presently neither the group nor array classes are hashable.
print('getting children')
observed = group.children
print(observed)
print(list([subgroup, subarray, implicit_subgroup]))
assert len(observed) == 3
assert subarray in observed
assert implicit_subgroup in observed
assert subgroup in observed




@pytest.mark.parametrize("store", (("local", "memory")), indirect=["store"])
def test_group(store: MemoryStore | LocalStore) -> None:
store_path = StorePath(store)
Expand Down

0 comments on commit 7dff5e5

Please sign in to comment.