Skip to content

Commit

Permalink
cleanup format (#154)
Browse files Browse the repository at this point in the history
  • Loading branch information
giovp authored Feb 27, 2023
1 parent 38c661c commit 8724f8a
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 53 deletions.
1 change: 1 addition & 0 deletions .editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
line_length = 120

[Makefile]
indent_style = tab
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Can't yet be moved to the pyproject.toml due to https://github.com/PyCQA/flake8/issues/234
[flake8]
max-line-length = 88
max-line-length = 120
ignore =
# line break before a binary operator -> black does not adhere to PEP8
W503
Expand Down
73 changes: 46 additions & 27 deletions spatialdata/_io/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,11 @@


class SpatialDataFormatV01(CurrentFormat):
"""
SpatialDataFormat defines the format of the spatialdata
package.
"""
"""SpatialDataFormat defines the format of the spatialdata package."""

@property
def spatialdata_version(self) -> str:
return "0.1"

def validate_table(
self,
table: AnnData,
region_key: Optional[str] = None,
instance_key: Optional[str] = None,
) -> None:
if not isinstance(table, AnnData):
raise TypeError(f"`tables` must be `anndata.AnnData`, was {type(table)}.")
if region_key is not None:
if not is_categorical_dtype(table.obs[region_key]):
raise ValueError(
f"`tables.obs[region_key]` must be of type `categorical`, not `{type(table.obs[region_key])}`."
)
if instance_key is not None:
if table.obs[instance_key].isnull().values.any():
raise ValueError("`tables.obs[instance_key]` must not contain null values, but it does.")
class RasterFormatV01(SpatialDataFormatV01):
"""Formatter for raster data."""

def generate_coordinate_transformations(self, shapes: list[tuple[Any]]) -> Optional[list[list[dict[str, Any]]]]:
data_shape = shapes[0]
Expand Down Expand Up @@ -114,9 +94,13 @@ def channels_from_metadata(self, omero_metadata: dict[str, Any]) -> list[Any]:
return [d["labels"] for d in omero_metadata["channels"]]


class ShapesFormat(SpatialDataFormatV01):
class ShapesFormatV01(SpatialDataFormatV01):
"""Formatter for shapes."""

@property
def version(self) -> str:
return "0.1"

def attrs_from_dict(self, metadata: dict[str, Any]) -> GeometryType:
if Shapes_s.ATTRS_KEY not in metadata:
raise KeyError(f"Missing key {Shapes_s.ATTRS_KEY} in shapes metadata.")
Expand All @@ -129,21 +113,25 @@ def attrs_from_dict(self, metadata: dict[str, Any]) -> GeometryType:

typ = GeometryType(metadata_[Shapes_s.GEOS_KEY][Shapes_s.TYPE_KEY])
assert typ.name == metadata_[Shapes_s.GEOS_KEY][Shapes_s.NAME_KEY]
assert self.spatialdata_version == metadata_["version"]
assert self.version == metadata_["version"]
return typ

def attrs_to_dict(self, geometry: GeometryType) -> dict[str, Union[str, dict[str, Any]]]:
return {Shapes_s.GEOS_KEY: {Shapes_s.NAME_KEY: geometry.name, Shapes_s.TYPE_KEY: geometry.value}}


class PointsFormat(SpatialDataFormatV01):
class PointsFormatV01(SpatialDataFormatV01):
"""Formatter for points."""

@property
def version(self) -> str:
return "0.1"

def attrs_from_dict(self, metadata: dict[str, Any]) -> dict[str, dict[str, Any]]:
if Points_s.ATTRS_KEY not in metadata:
raise KeyError(f"Missing key {Points_s.ATTRS_KEY} in points metadata.")
metadata_ = metadata[Points_s.ATTRS_KEY]
assert self.spatialdata_version == metadata_["version"]
assert self.version == metadata_["version"]
d = {}
if Points_s.FEATURE_KEY in metadata_:
d[Points_s.FEATURE_KEY] = metadata_[Points_s.FEATURE_KEY]
Expand All @@ -159,3 +147,34 @@ def attrs_to_dict(self, data: dict[str, Any]) -> dict[str, dict[str, Any]]:
if Points_s.FEATURE_KEY in data[Points_s.ATTRS_KEY]:
d[Points_s.FEATURE_KEY] = data[Points_s.ATTRS_KEY][Points_s.FEATURE_KEY]
return d


class TablesFormatV01(SpatialDataFormatV01):
    """Formatter for tables."""

    @property
    def version(self) -> str:
        """Version string of the tables format."""
        return "0.1"

    def validate_table(
        self,
        table: AnnData,
        region_key: Optional[str] = None,
        instance_key: Optional[str] = None,
    ) -> None:
        """
        Check that a table satisfies the constraints of this format.

        Parameters
        ----------
        table
            The object to validate; it must be an ``anndata.AnnData``.
        region_key
            Name of a column of ``table.obs``. When given, that column must have a categorical dtype.
        instance_key
            Name of a column of ``table.obs``. When given, that column must not contain null values.

        Raises
        ------
        TypeError
            If ``table`` is not an ``anndata.AnnData``.
        ValueError
            If the ``region_key`` column is not categorical, or the ``instance_key`` column contains nulls.
        """
        if not isinstance(table, AnnData):
            raise TypeError(f"`tables` must be `anndata.AnnData`, was {type(table)}.")
        if region_key is not None and not is_categorical_dtype(table.obs[region_key]):
            # Report the column's dtype: `type(...)` of the column is always the pandas Series class,
            # which tells the caller nothing about what is actually wrong.
            raise ValueError(
                f"`tables.obs[region_key]` must be of type `categorical`, not `{table.obs[region_key].dtype}`."
            )
        if instance_key is not None and table.obs[instance_key].isnull().values.any():
            raise ValueError("`tables.obs[instance_key]` must not contain null values, but it does.")


# Aliases pointing at the format class currently in use for each element type.
# Code that imports these names (e.g. as the default `fmt` argument of the read/write
# helpers) keeps working unchanged when an alias is re-pointed to a different class.
CurrentRasterFormat = RasterFormatV01
CurrentShapesFormat = ShapesFormatV01
CurrentPointsFormat = PointsFormatV01
CurrentTablesFormat = TablesFormatV01
15 changes: 10 additions & 5 deletions spatialdata/_io/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@
from spatialdata._core.ngff.ngff_transformations import NgffBaseTransformation
from spatialdata._core.transformations import BaseTransformation
from spatialdata._io._utils import ome_zarr_logger
from spatialdata._io.format import PointsFormat, ShapesFormat, SpatialDataFormatV01
from spatialdata._io.format import (
CurrentPointsFormat,
CurrentRasterFormat,
CurrentShapesFormat,
SpatialDataFormatV01,
)


def read_zarr(store: Union[str, Path, zarr.Group]) -> SpatialData:
Expand Down Expand Up @@ -122,7 +127,7 @@ def _get_transformations_from_ngff_dict(


def _read_multiscale(
store: str, raster_type: Literal["image", "labels"], fmt: SpatialDataFormatV01 = SpatialDataFormatV01()
store: str, raster_type: Literal["image", "labels"], fmt: SpatialDataFormatV01 = CurrentRasterFormat()
) -> Union[SpatialImage, MultiscaleSpatialImage]:
assert isinstance(store, str)
assert raster_type in ["image", "labels"]
Expand Down Expand Up @@ -159,7 +164,7 @@ def _read_multiscale(
# if image, read channels metadata
if raster_type == "image":
omero = multiscales[0]["omero"]
channels = fmt.channels_from_metadata(omero)
channels: list[Any] = fmt.channels_from_metadata(omero)
axes = [i["name"] for i in node.metadata["axes"]]
if len(datasets) > 1:
multiscale_image = {}
Expand Down Expand Up @@ -188,7 +193,7 @@ def _read_multiscale(
return compute_coordinates(si)


def _read_shapes(store: Union[str, Path, MutableMapping, zarr.Group], fmt: SpatialDataFormatV01 = ShapesFormat()) -> GeoDataFrame: # type: ignore[type-arg]
def _read_shapes(store: Union[str, Path, MutableMapping, zarr.Group], fmt: SpatialDataFormatV01 = CurrentShapesFormat()) -> GeoDataFrame: # type: ignore[type-arg]
"""Read shapes from a zarr store."""
assert isinstance(store, str)
f = zarr.open(store, mode="r")
Expand All @@ -212,7 +217,7 @@ def _read_shapes(store: Union[str, Path, MutableMapping, zarr.Group], fmt: Spati


def _read_points(
store: Union[str, Path, MutableMapping, zarr.Group], fmt: SpatialDataFormatV01 = PointsFormat() # type: ignore[type-arg]
store: Union[str, Path, MutableMapping, zarr.Group], fmt: SpatialDataFormatV01 = CurrentPointsFormat() # type: ignore[type-arg]
) -> DaskDataFrame:
"""Read points from a zarr store."""
assert isinstance(store, str)
Expand Down
34 changes: 18 additions & 16 deletions spatialdata/_io/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@
)
from spatialdata._core.models import ShapesModel
from spatialdata._core.transformations import _get_current_output_axes
from spatialdata._io.format import PointsFormat, ShapesFormat, SpatialDataFormatV01
from spatialdata._io.format import (
CurrentPointsFormat,
CurrentRasterFormat,
CurrentShapesFormat,
CurrentTablesFormat,
)

__all__ = [
"write_image",
Expand Down Expand Up @@ -89,15 +94,14 @@ def overwrite_coordinate_transformations_raster(
def _write_metadata(
group: zarr.Group,
group_type: str,
# coordinate_transformations: list[dict[str, Any]],
fmt: Format,
axes: Optional[Union[str, list[str], list[dict[str, str]]]] = None,
attrs: Optional[Mapping[str, Any]] = None,
fmt: Format = SpatialDataFormatV01(),
) -> None:
"""Write metadata to a group."""
axes = _get_valid_axes(axes=axes, fmt=fmt)

group.attrs["@type"] = group_type
group.attrs["encoding-type"] = group_type
group.attrs["axes"] = axes
# we write empty coordinateTransformations and then overwrite them with overwrite_coordinate_transformations_non_raster()
group.attrs["coordinateTransformations"] = []
Expand All @@ -110,7 +114,7 @@ def _write_raster(
raster_data: Union[SpatialImage, MultiscaleSpatialImage],
group: zarr.Group,
name: str,
fmt: Format = SpatialDataFormatV01(),
fmt: Format = CurrentRasterFormat(),
storage_options: Optional[Union[JSONDict, list[JSONDict]]] = None,
label_metadata: Optional[JSONDict] = None,
channels_metadata: Optional[JSONDict] = None,
Expand Down Expand Up @@ -212,7 +216,7 @@ def write_image(
image: Union[SpatialImage, MultiscaleSpatialImage],
group: zarr.Group,
name: str,
fmt: Format = SpatialDataFormatV01(),
fmt: Format = CurrentRasterFormat(),
storage_options: Optional[Union[JSONDict, list[JSONDict]]] = None,
**metadata: Union[str, JSONDict, list[JSONDict]],
) -> None:
Expand All @@ -231,7 +235,7 @@ def write_labels(
labels: Union[SpatialImage, MultiscaleSpatialImage],
group: zarr.Group,
name: str,
fmt: Format = SpatialDataFormatV01(),
fmt: Format = CurrentRasterFormat(),
storage_options: Optional[Union[JSONDict, list[JSONDict]]] = None,
label_metadata: Optional[JSONDict] = None,
**metadata: JSONDict,
Expand All @@ -253,7 +257,7 @@ def write_shapes(
group: zarr.Group,
name: str,
group_type: str = "ngff:shapes",
fmt: Format = ShapesFormat(),
fmt: Format = CurrentShapesFormat(),
) -> None:
axes = get_dims(shapes)
t = _get_transformations(shapes)
Expand All @@ -268,12 +272,11 @@ def write_shapes(
shapes_group.create_dataset(name=ShapesModel.RADIUS_KEY, data=shapes[ShapesModel.RADIUS_KEY].values)

attrs = fmt.attrs_to_dict(geometry)
attrs["version"] = fmt.spatialdata_version
attrs["version"] = fmt.version

_write_metadata(
shapes_group,
group_type=group_type,
# coordinate_transformations=coordinate_transformations,
axes=list(axes),
attrs=attrs,
fmt=fmt,
Expand All @@ -287,7 +290,7 @@ def write_points(
group: zarr.Group,
name: str,
group_type: str = "ngff:points",
fmt: Format = PointsFormat(),
fmt: Format = CurrentPointsFormat(),
) -> None:
axes = get_dims(points)
t = _get_transformations(points)
Expand All @@ -297,12 +300,11 @@ def write_points(
points.to_parquet(path)

attrs = fmt.attrs_to_dict(points.attrs)
attrs["version"] = fmt.spatialdata_version
attrs["version"] = fmt.version

_write_metadata(
points_groups,
group_type=group_type,
# coordinate_transformations=coordinate_transformations,
axes=list(axes),
attrs=attrs,
fmt=fmt,
Expand All @@ -316,19 +318,19 @@ def write_table(
group: zarr.Group,
name: str,
group_type: str = "ngff:regions_table",
fmt: Format = SpatialDataFormatV01(),
fmt: Format = CurrentTablesFormat(),
) -> None:
region = table.uns["spatialdata_attrs"]["region"]
region_key = table.uns["spatialdata_attrs"].get("region_key", None)
instance_key = table.uns["spatialdata_attrs"].get("instance_key", None)
fmt.validate_table(table, region_key, instance_key)
write_adata(group, name, table) # creates group[name]
tables_group = group[name]
tables_group.attrs["@type"] = group_type
tables_group.attrs["spatialdata-encoding-type"] = group_type
tables_group.attrs["region"] = region
tables_group.attrs["region_key"] = region_key
tables_group.attrs["instance_key"] = instance_key
tables_group.attrs["version"] = fmt.spatialdata_version
tables_group.attrs["version"] = fmt.version


def _iter_multiscale(
Expand Down
38 changes: 34 additions & 4 deletions tests/_io/test_format.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Any, Optional

import pytest
from shapely import GeometryType

from spatialdata._core.models import PointsModel
from spatialdata._io.format import PointsFormat
from spatialdata._core.models import PointsModel, ShapesModel
from spatialdata._io.format import CurrentPointsFormat, CurrentShapesFormat

Points_f = PointsFormat()
Points_f = CurrentPointsFormat()
Shapes_f = CurrentShapesFormat()


class TestFormat:
Expand All @@ -20,7 +22,7 @@ def test_format_points(
feature_key: Optional[str],
instance_key: Optional[str],
) -> None:
metadata: dict[str, Any] = {attrs_key: {"version": Points_f.spatialdata_version}}
metadata: dict[str, Any] = {attrs_key: {"version": Points_f.version}}
format_metadata: dict[str, Any] = {attrs_key: {}}
if feature_key is not None:
metadata[attrs_key][feature_key] = "target"
Expand All @@ -31,3 +33,31 @@ def test_format_points(
assert metadata[attrs_key] == Points_f.attrs_to_dict(format_metadata)
if feature_key is None and instance_key is None:
assert len(format_metadata[attrs_key]) == len(metadata[attrs_key]) == 0

@pytest.mark.parametrize("attrs_key", [ShapesModel.ATTRS_KEY])
@pytest.mark.parametrize("geos_key", [ShapesModel.GEOS_KEY])
@pytest.mark.parametrize("type_key", [ShapesModel.TYPE_KEY])
@pytest.mark.parametrize("name_key", [ShapesModel.NAME_KEY])
@pytest.mark.parametrize("shapes_type", [0, 3, 6])
def test_format_shapes(
    self,
    attrs_key: str,
    geos_key: str,
    type_key: str,
    name_key: str,
    shapes_type: int,
) -> None:
    """Round-trip shapes metadata: parse it with `attrs_from_dict`, write it back with `attrs_to_dict`."""
    # Geometry type code -> name pairs exercised by the parametrization above.
    shapes_dict = {
        0: "POINT",
        3: "POLYGON",
        6: "MULTIPOLYGON",
    }
    metadata: dict[str, Any] = {
        attrs_key: {
            "version": Shapes_f.version,
            geos_key: {
                type_key: shapes_type,
                name_key: shapes_dict[shapes_type],
            },
        }
    }

    # Check the parsed value itself, so a wrong geometry type from the reader is caught.
    geometry = Shapes_f.attrs_from_dict(metadata)
    assert geometry == GeometryType(shapes_type)

    # `attrs_to_dict` does not emit the version, so drop it before comparing.
    metadata[attrs_key].pop("version")
    assert metadata[attrs_key] == Shapes_f.attrs_to_dict(geometry)

0 comments on commit 8724f8a

Please sign in to comment.