diff --git a/src/spatialdata/_core/operations/aggregate.py b/src/spatialdata/_core/operations/aggregate.py index e1a13b6b..9bc2af09 100644 --- a/src/spatialdata/_core/operations/aggregate.py +++ b/src/spatialdata/_core/operations/aggregate.py @@ -176,6 +176,8 @@ def aggregate( by_ = transform(by_, to_coordinate_system=target_coordinate_system) values_ = transform(values_, to_coordinate_system=target_coordinate_system) + table_name = table_name if table_name is not None else "table" + # dispatch adata = None if by_type is ShapesModel and values_type in [PointsModel, ShapesModel]: @@ -218,7 +220,6 @@ def aggregate( if adata is None: raise NotImplementedError(f"Cannot aggregate {values_type} by {by_type}") - table_name = table_name if table_name is not None else "table" # create a SpatialData object with the aggregated table and the "by" shapes shapes_name = by if isinstance(by, str) else "by" return _create_sdata_from_table_and_shapes( diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index b90d915f..f9dde693 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -532,6 +532,26 @@ def _get_group_for_element(self, name: str, element_type: str) -> zarr.Group: element_type_group = root.require_group(element_type) return element_type_group.require_group(name) + def _group_for_element_exists(self, name: str, element_type: str) -> bool: + """ + Check if the group for an element exists. + + Parameters + ---------- + name + name of the element + element_type + type of the element. Should be in ["images", "labels", "points", "polygons", "shapes"]. + + Returns + ------- + True if the group exists, False otherwise. + """ + store = parse_url(self.path, mode="r").store + root = zarr.group(store=store) + assert element_type in ["images", "labels", "points", "polygons", "shapes"] + return element_type in root and name in root[element_type] + def _init_add_element(self, name: str, element_type: str, overwrite: bool) -> zarr.Group: store = parse_url(self.path, mode="r+").store root = zarr.group(store=store) @@ -608,24 +628,32 @@ def _write_transformations_to_disk(self, element: SpatialElement) -> None: if self.path is not None: for path in located: found_element_type, found_element_name = path.split("/") - group = self._get_group_for_element(name=found_element_name, element_type=found_element_type) - axes = get_axes_names(element) - if isinstance(element, (SpatialImage, MultiscaleSpatialImage)): - from spatialdata._io._utils import ( - overwrite_coordinate_transformations_raster, - ) + if self._group_for_element_exists(found_element_name, found_element_type): + group = self._get_group_for_element(name=found_element_name, element_type=found_element_type) + axes = get_axes_names(element) + if isinstance(element, (SpatialImage, MultiscaleSpatialImage)): + from spatialdata._io._utils import ( + overwrite_coordinate_transformations_raster, + ) - overwrite_coordinate_transformations_raster(group=group, axes=axes, transformations=transformations) - elif isinstance(element, (DaskDataFrame, GeoDataFrame, AnnData)): - from spatialdata._io._utils import ( - overwrite_coordinate_transformations_non_raster, - ) + overwrite_coordinate_transformations_raster( + group=group, axes=axes, transformations=transformations + ) + elif isinstance(element, (DaskDataFrame, GeoDataFrame, AnnData)): + from spatialdata._io._utils import ( + overwrite_coordinate_transformations_non_raster, + ) - overwrite_coordinate_transformations_non_raster( - group=group, axes=axes, transformations=transformations - ) + overwrite_coordinate_transformations_non_raster( + group=group, axes=axes, transformations=transformations + ) + else: + raise ValueError("Unknown element type") else: - raise ValueError("Unknown element type") + logger.info( + f"Not saving the transformation to element {found_element_type}/{found_element_name} as it is" + " not found in Zarr storage" + ) @deprecation_alias(filter_table="filter_tables") def filter_by_coordinate_system( diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 8ea40728..fb1ae955 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -471,7 +471,7 @@ def _get_tile_coords( # extent, aka the tile size extent = (circles.radius * 2).values.reshape(-1, 1) - centroids_points = get_centroids(circles) + centroids_points = get_centroids(circles, coordinate_system=cs) axes = get_axes_names(centroids_points) centroids_numpy = centroids_points.compute().values