diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 76de711c..9ed18071 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -58,31 +58,6 @@ def get_element_annotators(sdata: SpatialData, element_name: str) -> set[str]: return table_names -def _filter_table_by_element_names(table: AnnData | None, element_names: str | list[str]) -> AnnData | None: - """ - Filter an AnnData table to keep only the rows that are in the coordinate system. - - Parameters - ---------- - table - The table to filter; if None, returns None - element_names - The element_names to keep in the tables obs.region column - - Returns - ------- - The filtered table, or None if the input table was None - """ - if table is None or not table.uns.get(TableModel.ATTRS_KEY): - return None - table_mapping_metadata = table.uns[TableModel.ATTRS_KEY] - region_key = table_mapping_metadata[TableModel.REGION_KEY_KEY] - table.obs = pd.DataFrame(table.obs) - table = table[table.obs[region_key].isin(element_names)].copy() - table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = table.obs[region_key].unique().tolist() - return table - - @singledispatch def get_element_instances( element: SpatialElement, @@ -110,8 +85,10 @@ def get_element_instances( def _( element: DataArray | DataTree, return_background: bool = False, -) -> pd.Index: +) -> pd.Index | None: model = get_model(element) + if model in [Image2DModel, Image3DModel]: + return None assert model in [Labels2DModel, Labels3DModel], "Expected a `Labels` element. Found an `Image` instead." if isinstance(element, DataArray): # get unique labels value (including 0 if present) @@ -145,8 +122,8 @@ def _( # TODO: replace function use throughout repo by `join_sdata_spatialelement_table` def _filter_table_by_elements( - table: AnnData | None, elements_dict: dict[str, dict[str, Any]], match_rows: bool = False -) -> AnnData | None: + table: AnnData | list[AnnData], elements_dict: dict[str, dict[str, Any]], match_rows: bool = False +) -> AnnData: """ Filter an AnnData table to keep only the rows that are in the elements. @@ -163,42 +140,38 @@ def _filter_table_by_elements( ------- The filtered table (eventually with reordered rows), or None if the input table was None. """ - assert set(elements_dict.keys()).issubset({"images", "labels", "shapes", "points"}) - assert len(elements_dict) > 0, "elements_dict must not be empty" - assert any( - len(elements) > 0 for elements in elements_dict.values() - ), "elements_dict must contain at least one dict which contains at least one element" - if table is None: - return None + + def _validate_elements_dict(elements_dict: dict[str, dict[str, Any]]) -> None: + assert set(elements_dict.keys()).issubset({"images", "labels", "shapes", "points"}) + assert len(elements_dict) > 0, "elements_dict must not be empty" + assert any( + len(elements) > 0 for elements in elements_dict.values() + ), "elements_dict must contain at least one dict which contains at least one element" + + def _get_matching_indices( + table: AnnData, region_key: str, instance_key: str, name: str, instances: ArrayLike + ) -> ArrayLike: + return ((table.obs[region_key] == name) & (table.obs[instance_key].isin(instances))).to_numpy() + + def _filter_table(table: AnnData, to_keep: ArrayLike) -> AnnData: + table.obs = pd.DataFrame(table.obs) + return table[to_keep, :] + + _validate_elements_dict(elements_dict) to_keep = np.zeros(len(table), dtype=bool) - region_key = table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] - instance_key = table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] - instances = None - for _, elements in elements_dict.items(): + _, region_key, instance_key = get_table_keys(table) + + for elements in elements_dict.values(): for name, element in elements.items(): - if get_model(element) == Labels2DModel or get_model(element) == Labels3DModel: - if isinstance(element, DataArray): - # get unique labels value (including 0 if present) - instances = da.unique(element.data).compute() - else: - assert isinstance(element, DataTree) - v = element["scale0"].values() - assert len(v) == 1 - xdata = next(iter(v)) - # can be slow - instances = da.unique(xdata.data).compute() - instances = np.sort(instances) - elif get_model(element) == ShapesModel: - instances = element.index.to_numpy() - elif get_model(element) == PointsModel: - instances = element.compute().index.to_numpy() - else: - continue - indices = ((table.obs[region_key] == name) & (table.obs[instance_key].isin(instances))).to_numpy() - to_keep = to_keep | indices + model = get_model(element) + instances = get_element_instances(element) + if instances is not None: + indices = _get_matching_indices(table, region_key, instance_key, name, instances) + to_keep |= indices + original_table = table - table.obs = pd.DataFrame(table.obs) - table = table[to_keep, :] + table = _filter_table(table, to_keep) + if match_rows: assert instances is not None assert isinstance(instances, np.ndarray) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 1c3affa2..6fa01b48 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -735,10 +735,17 @@ def _filter_tables( continue # each mode here requires paths or elements, using assert here to avoid mypy errors. if by == "cs": - from spatialdata._core.query.relational_query import _filter_table_by_element_names + from spatialdata._core.query.relational_query import _filter_table_by_elements assert element_names is not None - table = _filter_table_by_element_names(table, element_names) + elements_dict = {} + for element_type in ["images", "labels", "shapes", "points"]: + elements = getattr(self, element_type) + if elements: # Check if the dictionary is not empty + elements_dict[element_type] = { + name: elements[name] for name in element_names if name in elements + } + table = _filter_table_by_elements(table, elements_dict=elements_dict) if len(table) != 0: tables[table_name] = table elif by == "elements": diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index 8a59147f..19a055a8 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -135,9 +135,15 @@ def test_filter_by_coordinate_system(full_sdata: SpatialData) -> None: def test_filter_by_coordinate_system_also_table(full_sdata: SpatialData) -> None: from spatialdata.models import TableModel - rng = np.random.default_rng(seed=0) - full_sdata["table"].obs["annotated_shapes"] = rng.choice(["circles", "poly"], size=full_sdata["table"].shape[0]) - adata = full_sdata["table"] + adata = full_sdata["table"].copy() + + circles_instances = full_sdata["circles"].index.values + poly_instances = full_sdata["poly"].index.values + + adata = adata[: len(circles_instances) + len(poly_instances), :].copy() + adata.obs["annotated_shapes"] = ["circles"] * len(circles_instances) + ["poly"] * len(poly_instances) + adata.obs["instance_id"] = np.concatenate([circles_instances, poly_instances]) + del adata.uns[TableModel.ATTRS_KEY] del full_sdata.tables["table"] full_sdata.table = TableModel.parse(