Skip to content

Commit

Permalink
Support for consolidated remote zarr (#278)
Browse files Browse the repository at this point in the history
Co-authored-by: Luca Marconato <[email protected]>
  • Loading branch information
berombau and LucaMarconato authored Jul 14, 2023
1 parent cdddb8b commit 9aa7525
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 66 deletions.
43 changes: 26 additions & 17 deletions src/spatialdata/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,48 @@
This module provides command line interface (CLI) interactions for the SpatialData library, allowing users to perform
various operations through a terminal. Currently, it implements the "peek" function, which allows users to inspect
the contents of a SpatialData .zarr file. Additional CLI functionalities will be implemented in the future.
the contents of a SpatialData .zarr dataset. Additional CLI functionalities will be implemented in the future.
"""
import os
from typing import Literal

import click


@click.command(help="Peek inside the SpatialData .zarr dataset")
@click.argument("path", default=False, type=str)
@click.argument("selection", type=click.Choice(["images", "labels", "points", "shapes", "table"]), nargs=-1)
def peek(path: str, selection: tuple[Literal["images", "labels", "points", "shapes", "table"]]) -> None:
    """
    Peek inside the SpatialData .zarr dataset.

    This function takes a path to a local or remote .zarr dataset, reads and prints
    its contents using the SpatialData library. If any ValueError is raised, it is caught and printed to the
    terminal along with a help message.

    Parameters
    ----------
    path
        The path to the .zarr dataset to be inspected.
    selection
        Optional, a list of keys (among images, labels, points, shapes, table) to load only a subset of the dataset.

        Example: `python -m spatialdata peek data.zarr images labels`
    """
    # deferred import keeps `python -m spatialdata --help` fast
    import spatialdata as sd

    try:
        sdata = sd.SpatialData.read(path, selection=selection)
        print(sdata)  # noqa: T201
    except ValueError as e:
        # checking if a valid path was provided is difficult given the various ways in which
        # a possibly remote path and storage access options can be specified
        # so we just catch the ValueError and print a help message
        print(e)  # noqa: T201
        print(  # noqa: T201
            f"Error: .zarr storage not found at {path}. Please specify a valid OME-NGFF spatial data (.zarr) file. "
            "Examples "
            '"python -m spatialdata peek data.zarr"'
            '"python -m spatialdata peek https://remote/.../data.zarr labels table"'
        )

@click.group()
Expand Down
29 changes: 25 additions & 4 deletions src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@

import zarr
from anndata import AnnData
from dask.dataframe import read_parquet
from dask.dataframe.core import DataFrame as DaskDataFrame
from dask.delayed import Delayed
from geopandas import GeoDataFrame
from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage
from ome_zarr.io import parse_url
from ome_zarr.types import JSONDict
from pyarrow.parquet import read_table
from spatial_image import SpatialImage

from spatialdata._io import (
Expand Down Expand Up @@ -917,6 +917,7 @@ def write(
file_path: str | Path,
storage_options: JSONDict | list[JSONDict] | None = None,
overwrite: bool = False,
consolidate_metadata: bool = True,
) -> None:
"""Write the SpatialData object to Zarr."""
if isinstance(file_path, str):
Expand Down Expand Up @@ -1052,6 +1053,12 @@ def write(
self.path = None
raise e

if consolidate_metadata:
# consolidate metadata to more easily support remote reading
# bug in zarr, 'zmetadata' is written instead of '.zmetadata'
# see discussion https://github.com/zarr-developers/zarr-python/issues/1121
zarr.consolidate_metadata(store, metadata_key=".zmetadata")

# old code to support overwriting the backing file
# if target_path is not None:
# if os.path.isdir(file_path):
Expand Down Expand Up @@ -1133,10 +1140,24 @@ def table(self) -> None:
del root["table/table"]

@staticmethod
def read(file_path: str) -> SpatialData:
def read(file_path: str, selection: tuple[str] | None = None) -> SpatialData:
"""
Read a SpatialData object from a Zarr storage (on-disk or remote).
Parameters
----------
file_path
The path or URL to the Zarr storage.
selection
The elements to read (images, labels, points, shapes, table). If None, all elements are read.
Returns
-------
The SpatialData object.
"""
from spatialdata import read_zarr

return read_zarr(file_path)
return read_zarr(file_path, selection=selection)

@property
def images(self) -> dict[str, SpatialImage | MultiscaleSpatialImage]:
Expand Down Expand Up @@ -1238,7 +1259,7 @@ def h(s: str) -> str:
assert isinstance(t, tuple)
assert len(t) == 1
parquet_file = t[0]
table = read_table(parquet_file)
table = read_parquet(parquet_file)
length = len(table)
else:
# length = len(v)
Expand Down
7 changes: 5 additions & 2 deletions src/spatialdata/_io/io_points.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from collections.abc import MutableMapping
from pathlib import Path
from typing import Union
Expand Down Expand Up @@ -29,8 +30,10 @@ def _read_points(
assert isinstance(store, (str, Path))
f = zarr.open(store, mode="r")

path = Path(f._store.path) / f.path / "points.parquet"
table = read_parquet(path)
path = os.path.join(f._store.path, f.path, "points.parquet")
# cache on remote file needed for parquet reader to work
# TODO: allow reading in the metadata without caching all the data
table = read_parquet("simplecache::" + path if "http" in path else path)
assert isinstance(table, DaskDataFrame)

transformations = _get_transformations_from_ngff_dict(f.attrs.asdict()["coordinateTransformations"])
Expand Down
155 changes: 114 additions & 41 deletions src/spatialdata/_io/io_zarr.py
Original file line number Diff line number Diff line change
@@ -1,74 +1,145 @@
import logging
import os
from pathlib import Path
from typing import Optional, Union

import numpy as np
import zarr
from anndata import AnnData
from anndata import read_zarr as read_anndata_zarr
from anndata.experimental import read_elem

from spatialdata import SpatialData
from spatialdata._io._utils import ome_zarr_logger
from spatialdata._io.io_points import _read_points
from spatialdata._io.io_raster import _read_multiscale
from spatialdata._io.io_shapes import _read_shapes
from spatialdata._logging import logger
from spatialdata.models import TableModel


def read_zarr(store: Union[str, Path, zarr.Group]) -> SpatialData:
if isinstance(store, str):
store = Path(store)
def _open_zarr_store(store: Union[str, Path, zarr.Group]) -> tuple[zarr.Group, str]:
    """
    Open an on-disk or remote zarr store and return the group together with its store path.

    Parameters
    ----------
    store
        Path to the zarr store (on-disk or remote) or an already-open zarr.Group object.

    Returns
    -------
    A tuple of the zarr.Group object and the path to the store.
    """
    if isinstance(store, zarr.Group):
        f = store
    else:
        f = zarr.open(store, mode="r")
    # workaround: .zmetadata is being written as zmetadata
    # (https://github.com/zarr-developers/zarr-python/issues/1121)
    looks_remote = isinstance(store, (str, Path)) and str(store).startswith("http")
    if looks_remote and len(f) == 0:
        f = zarr.open_consolidated(store, mode="r", metadata_key="zmetadata")
    if isinstance(f.store, zarr.storage.ConsolidatedMetadataStore):
        f_store_path = f.store.store.path
    else:
        f_store_path = f.store.path
    return f, f_store_path


def read_zarr(store: Union[str, Path, zarr.Group], selection: Optional[tuple[str]] = None) -> SpatialData:
"""
Read a SpatialData dataset from a zarr store (on-disk or remote).
Parameters
----------
store
Path to the zarr store (on-disk or remote) or a zarr.Group object.
selection
List of elements to read from the zarr store (images, labels, points, shapes, table). If None, all elements are
read.
Returns
-------
A SpatialData object.
"""
f, f_store_path = _open_zarr_store(store)

f = zarr.open(store, mode="r")
images = {}
labels = {}
points = {}
table: Optional[AnnData] = None
shapes = {}

selector = {"images", "labels", "points", "shapes", "table"} if not selection else set(selection or [])
logger.debug(f"Reading selection {selector}")

# read multiscale images
images_store = store / "images"
if images_store.exists():
f = zarr.open(images_store, mode="r")
for k in f:
f_elem = f[k].name
f_elem_store = f"{images_store}{f_elem}"
images[k] = _read_multiscale(f_elem_store, raster_type="image")
if "images" in selector and "images" in f:
group = f["images"]
count = 0
for subgroup_name in group:
if Path(subgroup_name).name.startswith("."):
# skip hidden files like .zgroup or .zmetadata
continue
f_elem = group[subgroup_name]
f_elem_store = os.path.join(f_store_path, f_elem.path)
element = _read_multiscale(f_elem_store, raster_type="image")
images[subgroup_name] = element
count += 1
logger.debug(f"Found {count} elements in {group}")

# read multiscale labels
with ome_zarr_logger(logging.ERROR):
labels_store = store / "labels"
if labels_store.exists():
f = zarr.open(labels_store, mode="r")
for k in f:
f_elem = f[k].name
f_elem_store = f"{labels_store}{f_elem}"
labels[k] = _read_multiscale(f_elem_store, raster_type="labels")
if "labels" in selector and "labels" in f:
group = f["labels"]
count = 0
for subgroup_name in group:
if Path(subgroup_name).name.startswith("."):
# skip hidden files like .zgroup or .zmetadata
continue
f_elem = group[subgroup_name]
f_elem_store = os.path.join(f_store_path, f_elem.path)
labels[subgroup_name] = _read_multiscale(f_elem_store, raster_type="labels")
count += 1
logger.debug(f"Found {count} elements in {group}")

# now read rest of the data
points_store = store / "points"
if points_store.exists():
f = zarr.open(points_store, mode="r")
for k in f:
f_elem = f[k].name
f_elem_store = f"{points_store}{f_elem}"
points[k] = _read_points(f_elem_store)

shapes_store = store / "shapes"
if shapes_store.exists():
f = zarr.open(shapes_store, mode="r")
for k in f:
f_elem = f[k].name
f_elem_store = f"{shapes_store}{f_elem}"
shapes[k] = _read_shapes(f_elem_store)

table_store = store / "table"
if table_store.exists():
f = zarr.open(table_store, mode="r")
for k in f:
f_elem = f[k].name
f_elem_store = f"{table_store}{f_elem}"
table = read_anndata_zarr(f_elem_store)
if "points" in selector and "points" in f:
group = f["points"]
count = 0
for subgroup_name in group:
f_elem = group[subgroup_name]
if Path(subgroup_name).name.startswith("."):
# skip hidden files like .zgroup or .zmetadata
continue
f_elem_store = os.path.join(f_store_path, f_elem.path)
points[subgroup_name] = _read_points(f_elem_store)
count += 1
logger.debug(f"Found {count} elements in {group}")

if "shapes" in selector and "shapes" in f:
group = f["shapes"]
count = 0
for subgroup_name in group:
if Path(subgroup_name).name.startswith("."):
# skip hidden files like .zgroup or .zmetadata
continue
f_elem = group[subgroup_name]
f_elem_store = os.path.join(f_store_path, f_elem.path)
shapes[subgroup_name] = _read_shapes(f_elem_store)
count += 1
logger.debug(f"Found {count} elements in {group}")

if "table" in selector and "table" in f:
group = f["table"]
count = 0
for subgroup_name in group:
if Path(subgroup_name).name.startswith("."):
# skip hidden files like .zgroup or .zmetadata
continue
f_elem = group[subgroup_name]
f_elem_store = os.path.join(f_store_path, f_elem.path)
if isinstance(f.store, zarr.storage.ConsolidatedMetadataStore):
table = read_elem(f_elem)
# we can replace read_elem with read_anndata_zarr after this PR gets into a release (>= 0.6.5)
# https://github.com/scverse/anndata/pull/1057#pullrequestreview-1530623183
# table = read_anndata_zarr(f_elem)
else:
table = read_anndata_zarr(f_elem_store)
if TableModel.ATTRS_KEY in table.uns:
# fill out eventual missing attributes that has been omitted because their value was None
attrs = table.uns[TableModel.ATTRS_KEY]
Expand All @@ -81,6 +152,8 @@ def read_zarr(store: Union[str, Path, zarr.Group]) -> SpatialData:
# fix type for region
if "region" in attrs and isinstance(attrs["region"], np.ndarray):
attrs["region"] = attrs["region"].tolist()
count += 1
logger.debug(f"Found {count} elements in {group}")

sdata = SpatialData(
images=images,
Expand Down
6 changes: 4 additions & 2 deletions src/spatialdata/_logging.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import logging
import os


def _setup_logger() -> "logging.Logger":
from rich.console import Console
from rich.logging import RichHandler

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
level = os.environ.get("LOGLEVEL", logging.INFO)
logger.setLevel(level=level)
console = Console(force_terminal=True)
if console.is_jupyter is True:
console.is_jupyter = False
ch = RichHandler(show_path=False, console=console, show_time=False)
ch = RichHandler(show_path=False, console=console, show_time=logger.level == logging.DEBUG)
logger.addHandler(ch)

# this prevents double outputs
Expand Down

0 comments on commit 9aa7525

Please sign in to comment.