Skip to content

Commit

Permalink
fixes in get_centroids, spatial_query for multiscale raster, dataload…
Browse files Browse the repository at this point in the history
…er (#497)

* fixes in get_centroids, spatial_query for multiscale raster, dataloader

* fix coordinate system dataloader

* fixed bug in save_transformation when sdata not None

* default value for table_name in aggregate

---------

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
LucaMarconato and Ubuntu authored Mar 21, 2024
1 parent aa78ad6 commit f455ab4
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 17 deletions.
3 changes: 2 additions & 1 deletion src/spatialdata/_core/operations/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ def aggregate(
by_ = transform(by_, to_coordinate_system=target_coordinate_system)
values_ = transform(values_, to_coordinate_system=target_coordinate_system)

table_name = table_name if table_name is not None else "table"

# dispatch
adata = None
if by_type is ShapesModel and values_type in [PointsModel, ShapesModel]:
Expand Down Expand Up @@ -218,7 +220,6 @@ def aggregate(
if adata is None:
raise NotImplementedError(f"Cannot aggregate {values_type} by {by_type}")

table_name = table_name if table_name is not None else "table"
# create a SpatialData object with the aggregated table and the "by" shapes
shapes_name = by if isinstance(by, str) else "by"
return _create_sdata_from_table_and_shapes(
Expand Down
58 changes: 43 additions & 15 deletions src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,26 @@ def _get_group_for_element(self, name: str, element_type: str) -> zarr.Group:
element_type_group = root.require_group(element_type)
return element_type_group.require_group(name)

def _group_for_element_exists(self, name: str, element_type: str) -> bool:
"""
Check if the group for an element exists.
Parameters
----------
name
name of the element
element_type
type of the element. Should be in ["images", "labels", "points", "polygons", "shapes"].
Returns
-------
True if the group exists, False otherwise.
"""
store = parse_url(self.path, mode="r").store
root = zarr.group(store=store)
assert element_type in ["images", "labels", "points", "polygons", "shapes"]
return element_type in root and name in root[element_type]

def _init_add_element(self, name: str, element_type: str, overwrite: bool) -> zarr.Group:
store = parse_url(self.path, mode="r+").store
root = zarr.group(store=store)
Expand Down Expand Up @@ -608,24 +628,32 @@ def _write_transformations_to_disk(self, element: SpatialElement) -> None:
if self.path is not None:
for path in located:
found_element_type, found_element_name = path.split("/")
group = self._get_group_for_element(name=found_element_name, element_type=found_element_type)
axes = get_axes_names(element)
if isinstance(element, (SpatialImage, MultiscaleSpatialImage)):
from spatialdata._io._utils import (
overwrite_coordinate_transformations_raster,
)
if self._group_for_element_exists(found_element_name, found_element_type):
group = self._get_group_for_element(name=found_element_name, element_type=found_element_type)
axes = get_axes_names(element)
if isinstance(element, (SpatialImage, MultiscaleSpatialImage)):
from spatialdata._io._utils import (
overwrite_coordinate_transformations_raster,
)

overwrite_coordinate_transformations_raster(group=group, axes=axes, transformations=transformations)
elif isinstance(element, (DaskDataFrame, GeoDataFrame, AnnData)):
from spatialdata._io._utils import (
overwrite_coordinate_transformations_non_raster,
)
overwrite_coordinate_transformations_raster(
group=group, axes=axes, transformations=transformations
)
elif isinstance(element, (DaskDataFrame, GeoDataFrame, AnnData)):
from spatialdata._io._utils import (
overwrite_coordinate_transformations_non_raster,
)

overwrite_coordinate_transformations_non_raster(
group=group, axes=axes, transformations=transformations
)
overwrite_coordinate_transformations_non_raster(
group=group, axes=axes, transformations=transformations
)
else:
raise ValueError("Unknown element type")
else:
raise ValueError("Unknown element type")
logger.info(
f"Not saving the transformation to element {found_element_type}/{found_element_name} as it is"
" not found in Zarr storage"
)

@deprecation_alias(filter_table="filter_tables")
def filter_by_coordinate_system(
Expand Down
2 changes: 1 addition & 1 deletion src/spatialdata/dataloader/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def _get_tile_coords(

# extent, aka the tile size
extent = (circles.radius * 2).values.reshape(-1, 1)
centroids_points = get_centroids(circles)
centroids_points = get_centroids(circles, coordinate_system=cs)
axes = get_axes_names(centroids_points)
centroids_numpy = centroids_points.compute().values

Expand Down

0 comments on commit f455ab4

Please sign in to comment.