From 2c6ae3f162d64d83780c2accaa73fa0f498eb4aa Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Tue, 26 Mar 2024 11:51:58 +0100 Subject: [PATCH] Join outside sdata (#512) * add join outside sdata * add join outside sdata * add tests * add tests * silence warnings * Fix all tests * change to new function name * adjust docs --- CHANGELOG.md | 5 +- docs/api.md | 2 +- src/spatialdata/__init__.py | 4 +- .../_core/query/relational_query.py | 146 +++++-- src/spatialdata/dataloader/datasets.py | 10 +- .../operations/test_spatialdata_operations.py | 34 +- tests/core/operations/test_transform.py | 2 +- tests/core/query/test_relational_query.py | 387 ++++++++++++++---- tests/core/query/test_spatial_query.py | 12 +- tests/datasets/test_datasets.py | 4 +- tests/io/test_readwrite.py | 18 +- 11 files changed, 468 insertions(+), 156 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dc9bb844..df74c179 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,10 +24,11 @@ and this project adheres to [Semantic Versioning][]. - Implemented support in SpatialData for storing multiple tables. These tables can annotate a SpatialElement but not necessarily so. -- Added SQL like joins that can be executed by calling one public function `join_sdata_spatialelement_table`. The +- Added SQL like joins that can be executed by calling one public function `join_spatialelement_table`. The following joins are supported: `left`, `left_exclusive`, `right`, `right_exclusive` and `inner`. The function has an option to match rows. For `left` only matching `left` is supported and for `right` join only `right` matching of - rows is supported. Not all joins are supported for `Labels` elements. + rows is supported. Not all joins are supported for `Labels` elements. The elements and table can either exist within + a `SpatialData` object or outside. - Added function `match_element_to_table` which allows the user to perform a right join of `SpatialElement`(s) with a table with rows matching the row order in the table. - Increased in-memory vs on-disk control: changes performed in-memory (e.g. adding a new image) are not automatically diff --git a/docs/api.md b/docs/api.md index c3e39478..70f59103 100644 --- a/docs/api.md +++ b/docs/api.md @@ -28,7 +28,7 @@ Operations on `SpatialData` objects. get_values get_extent get_centroids - join_sdata_spatialelement_table + join_spatialelement_table match_element_to_table get_centroids match_table_to_element diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index d899a6ec..18eae0f8 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -23,7 +23,7 @@ "bounding_box_query", "polygon_query", "get_values", - "join_sdata_spatialelement_table", + "join_spatialelement_table", "match_element_to_table", "match_table_to_element", "SpatialData", @@ -49,7 +49,7 @@ from spatialdata._core.query._utils import circles_to_polygons, get_bounding_box_corners from spatialdata._core.query.relational_query import ( get_values, - join_sdata_spatialelement_table, + join_spatialelement_table, match_element_to_table, match_table_to_element, ) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 46713e64..5a24b0a9 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -19,6 +19,8 @@ from spatialdata._core.spatialdata import SpatialData from spatialdata._utils import _inplace_fix_subset_categorical_obs from spatialdata.models import ( + Image2DModel, + Image3DModel, Labels2DModel, Labels3DModel, PointsModel, @@ -462,14 +464,78 @@ class MatchTypes(Enum): no = "no" -def join_sdata_spatialelement_table( - sdata: SpatialData, - spatial_element_name: str | list[str], - table_name: str, - how: str = "left", +def _create_sdata_elements_dict_for_join( + sdata: SpatialData, spatial_element_name: str | list[str], table_name: str +) -> tuple[dict[str, dict[str, Any]], AnnData]: + assert sdata.tables.get(table_name), f"No table with `{table_name}` exists in the SpatialData object." + table = sdata.tables[table_name] + + elements_dict: dict[str, dict[str, Any]] = defaultdict(lambda: defaultdict(dict)) + for name in spatial_element_name: + if name in sdata.tables: + warnings.warn( + f"Table: `{name}` given in spatial_element_names cannot be " + f"joined with a table using this function.", + UserWarning, + stacklevel=2, + ) + elif name in sdata.images: + warnings.warn( + f"Image: `{name}` cannot be joined with a table", + UserWarning, + stacklevel=2, + ) + else: + element_type, _, element = sdata._find_element(name) + elements_dict[element_type][name] = element + return elements_dict, table + + +def _create_elements_dict_for_join( + spatial_element_name: str | list[str], elements: SpatialElement | list[SpatialElement], table: AnnData +) -> dict[str, dict[str, Any]]: + elements = elements if isinstance(elements, list) else [elements] + + elements_dict: dict[str, dict[str, Any]] = defaultdict(lambda: defaultdict(dict)) + for name, element in zip(spatial_element_name, elements): + model = get_model(element) + + if model == TableModel: + warnings.warn( + f"Table: `{name}` given in spatial_element_name cannot be " f"joined with a table using this function.", + UserWarning, + stacklevel=2, + ) + continue + if model in [Image2DModel, Image3DModel]: + warnings.warn( + f"Image: `{name}` cannot be joined with a table", + UserWarning, + stacklevel=2, + ) + continue + + if model in [Labels2DModel, Labels3DModel]: + element_type = "labels" + elif model == PointsModel: + element_type = "points" + elif model == ShapesModel: + element_type = "shapes" + elements_dict[element_type][name] = element + return elements_dict + + +def join_spatialelement_table( + spatial_element_names: str | list[str], + elements: SpatialElement | list[SpatialElement] | None = None, + table: AnnData | None = None, + table_name: str | None = None, + sdata: SpatialData | None = None, + how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "left", match_rows: Literal["no", "left", "right"] = "no", ) -> tuple[dict[str, Any], AnnData]: - """Join SpatialElement(s) and table together in SQL like manner. + """ + Join SpatialElement(s) and table together in SQL like manner. The function allows the user to perform SQL like joins of SpatialElements and a table. The elements are not returned together in one dataframe like structure, but instead filtered elements are returned. To determine matches, @@ -489,10 +555,12 @@ def join_sdata_spatialelement_table( Parameters ---------- - sdata - The SpatialData object containing the tables and spatial elements. - spatial_element_name - The name(s) of the spatial elements to be joined with the table. + spatial_element_names + The name(s) of the spatial elements to be joined with the table. If a list of names the indices must match + with the list of SpatialElements passed on by the argument elements. + elements + The SpatialElement(s) to be joined with the table. In case of a list of SpatialElements the indices + must match exactly with the indices in the list of spatial_element_name. table_name The name of the table to join with the spatial elements. how @@ -508,35 +576,41 @@ def join_sdata_spatialelement_table( Raises ------ - AssertionError - If no table with the given table_name exists in the SpatialData object. + ValueError + If table_name is provided but not present in the SpatialData object. + ValueError + If no valid elements are provided for the join operation. ValueError If the provided join type is not supported. + ValueError + If an incorrect value is given for match_rows. """ - assert sdata.tables.get(table_name), f"No table with `{table_name}` exists in the SpatialData object." - table = sdata.tables[table_name] - if isinstance(spatial_element_name, str): - spatial_element_name = [spatial_element_name] - - elements_dict: dict[str, dict[str, Any]] = defaultdict(lambda: defaultdict(dict)) - for name in spatial_element_name: - if name in sdata.tables: - warnings.warn( - f"Tables: `{', '.join(elements_dict['tables'].keys())}` given in spatial_element_names cannot be " - f"joined with a table using this function.", - UserWarning, - stacklevel=2, - ) - elif name in sdata.images: - warnings.warn( - f"Images: `{', '.join(elements_dict['images'].keys())}` cannot be joined with a table", - UserWarning, - stacklevel=2, - ) + spatial_element_names = ( + spatial_element_names if isinstance(spatial_element_names, list) else [spatial_element_names] + ) + sdata_args = [sdata, table_name] + non_sdata_args = [elements, table] + if any(arg is not None for arg in sdata_args): + assert all( + arg is None for arg in non_sdata_args + ), "If `sdata` and `table_name` are specified, `elements` and `table` should not be specified." + if sdata is not None and table_name is not None: + elements_dict, table = _create_sdata_elements_dict_for_join(sdata, spatial_element_names, table_name) else: - element_type, _, element = sdata._find_element(name) - elements_dict[element_type][name] = element + raise ValueError("If either `sdata` or `table_name` is specified, both should be specified.") + if any(arg is not None for arg in non_sdata_args): + assert all( + arg is not None for arg in non_sdata_args + ), "both `elements` and `table` must be given if either is specified." + elements_dict = _create_elements_dict_for_join(spatial_element_names, elements, table_name) + + elements_dict, table = _call_join(elements_dict, table, how, match_rows) + return elements_dict, table + +def _call_join( + elements_dict: dict[str, dict[str, Any]], table: AnnData, how: str, match_rows: Literal["no", "left", "right"] +) -> tuple[dict[str, Any], AnnData]: assert any(key in elements_dict for key in ["labels", "shapes", "points"]), ( "No valid element to join in spatial_element_name. Must provide at least one of either `labels`, `points` or " "`shapes`." @@ -611,7 +685,9 @@ def match_element_to_table( ------- A tuple containing the joined elements as a dictionary and the joined table as an AnnData object. """ - element_dict, table = join_sdata_spatialelement_table(sdata, element_name, table_name, "right", match_rows="right") + element_dict, table = join_spatialelement_table( + sdata=sdata, spatial_element_names=element_name, table_name=table_name, how="right", match_rows="right" + ) return element_dict, table diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index fb1ae955..b19c2237 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -20,7 +20,7 @@ from spatialdata._core.centroids import get_centroids from spatialdata._core.operations.transform import transform from spatialdata._core.operations.vectorize import to_circles -from spatialdata._core.query.relational_query import _get_unique_label_values_as_index, join_sdata_spatialelement_table +from spatialdata._core.query.relational_query import _get_unique_label_values_as_index, join_spatialelement_table from spatialdata._core.spatialdata import SpatialData from spatialdata.models import ( Image2DModel, @@ -261,8 +261,12 @@ def _preprocess( if table_name is not None: table_subset = filtered_table[filtered_table.obs[region_key] == region_name] circles_sdata = SpatialData.init_from_elements({region_name: circles}, tables=table_subset.copy()) - _, table = join_sdata_spatialelement_table( - circles_sdata, region_name, table_name, how="left", match_rows="left" + _, table = join_spatialelement_table( + sdata=circles_sdata, + spatial_element_names=region_name, + table_name=table_name, + how="left", + match_rows="left", ) # get index dictionary, with `instance_id`, `cs`, `region`, and `image` tables_l.append(table) diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index 455caa48..d6f7cb6f 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -129,10 +129,10 @@ def test_filter_by_coordinate_system_also_table(full_sdata: SpatialData) -> None from spatialdata.models import TableModel rng = np.random.default_rng(seed=0) - full_sdata.table.obs["annotated_shapes"] = rng.choice(["circles", "poly"], size=full_sdata.table.shape[0]) - adata = full_sdata.table + full_sdata["table"].obs["annotated_shapes"] = rng.choice(["circles", "poly"], size=full_sdata["table"].shape[0]) + adata = full_sdata["table"] del adata.uns[TableModel.ATTRS_KEY] - del full_sdata.table + del full_sdata.tables["table"] full_sdata.table = TableModel.parse( adata, region=["circles", "poly"], region_key="annotated_shapes", instance_key="instance_id" ) @@ -145,8 +145,8 @@ def test_filter_by_coordinate_system_also_table(full_sdata: SpatialData) -> None filtered_sdata1 = full_sdata.filter_by_coordinate_system(coordinate_system="my_space1") filtered_sdata2 = full_sdata.filter_by_coordinate_system(coordinate_system="my_space0", filter_table=False) - assert len(filtered_sdata0.table) + len(filtered_sdata1.table) == len(full_sdata.table) - assert len(filtered_sdata2.table) == len(full_sdata.table) + assert len(filtered_sdata0["table"]) + len(filtered_sdata1["table"]) == len(full_sdata["table"]) + assert len(filtered_sdata2["table"]) == len(full_sdata["table"]) def test_rename_coordinate_systems(full_sdata: SpatialData) -> None: @@ -257,7 +257,7 @@ def test_concatenate_sdatas(full_sdata: SpatialData) -> None: with pytest.raises(KeyError): concatenate([full_sdata, SpatialData(shapes={"circles": full_sdata.shapes["circles"]})]) - assert concatenate([full_sdata, SpatialData()]).table is not None + assert concatenate([full_sdata, SpatialData()])["table"] is not None set_transformation(full_sdata.shapes["circles"], Identity(), "my_space0") set_transformation(full_sdata.shapes["poly"], Identity(), "my_space1") @@ -268,11 +268,11 @@ def test_concatenate_sdatas(full_sdata: SpatialData) -> None: # this is needed cause we can't handle regions with same name. # TODO: fix this new_region = "sample2" - table_new = filtered1.table.copy() - del filtered1.table - filtered1.table = table_new - filtered1.table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = new_region - filtered1.table.obs[filtered1.table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY]] = new_region + table_new = filtered1["table"].copy() + del filtered1.tables["table"] + filtered1["table"] = table_new + filtered1["table"].uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = new_region + filtered1["table"].obs[filtered1["table"].uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY]] = new_region concatenated = concatenate([filtered0, filtered1], concatenate_tables=True) assert len(list(concatenated.gen_elements())) == 3 @@ -341,20 +341,20 @@ def test_subset(full_sdata: SpatialData) -> None: assert "image3d_xarray" in full_sdata.images assert unique_names == set(element_names) # no table since the labels are not present in the subset - assert subset0.table is None + assert "table" not in subset0.tables adata = AnnData( shape=(10, 0), obs={"region": ["circles"] * 5 + ["poly"] * 5, "instance_id": [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]}, ) - del full_sdata.table + del full_sdata.tables["table"] sdata_table = TableModel.parse(adata, region=["circles", "poly"], region_key="region", instance_key="instance_id") - full_sdata.table = sdata_table + full_sdata["table"] = sdata_table full_sdata.tables["second_table"] = sdata_table subset1 = full_sdata.subset(["poly", "second_table"]) - assert subset1.table is not None - assert len(subset1.table) == 5 - assert subset1.table.obs["region"].unique().tolist() == ["poly"] + assert subset1["table"] is not None + assert len(subset1["table"]) == 5 + assert subset1["table"].obs["region"].unique().tolist() == ["poly"] assert len(subset1["second_table"]) == 10 diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index 5ad61dec..98610f79 100644 --- a/tests/core/operations/test_transform.py +++ b/tests/core/operations/test_transform.py @@ -549,7 +549,7 @@ def test_transform_elements_and_entire_spatial_data_object_multi_hop( labels=dict(full_sdata.labels), points=dict(full_sdata.points), shapes=dict(full_sdata.shapes), - table=full_sdata.table, + table=full_sdata["table"], ) temp["transformed_element"] = transformed_element transformation = get_transformation_between_coordinate_systems( diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index db4cef3c..96c61139 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -7,7 +7,7 @@ _get_element_annotators, _locate_value, _ValueOrigin, - join_sdata_spatialelement_table, + join_spatialelement_table, ) from spatialdata.models.models import TableModel @@ -36,54 +36,110 @@ def test_join_using_string_instance_id_and_index(sdata_query_aggregation): sdata_query_aggregation["values_polygons"] = sdata_query_aggregation["values_polygons"][:5] sdata_query_aggregation["values_circles"] = sdata_query_aggregation["values_circles"][:5] - element_dict, table = join_sdata_spatialelement_table( - sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "inner" + element_dict, table = join_spatialelement_table( + sdata=sdata_query_aggregation, + spatial_element_names=["values_circles", "values_polygons"], + table_name="table", + how="inner", ) # Note that we started with 21 n_obs. assert table.n_obs == 10 - element_dict, table = join_sdata_spatialelement_table( - sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "right_exclusive" + element_dict, table = join_spatialelement_table( + sdata=sdata_query_aggregation, + spatial_element_names=["values_circles", "values_polygons"], + table_name="table", + how="right_exclusive", ) assert table.n_obs == 11 - element_dict, table = join_sdata_spatialelement_table( - sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "right" + element_dict, table = join_spatialelement_table( + sdata=sdata_query_aggregation, + spatial_element_names=["values_circles", "values_polygons"], + table_name="table", + how="right", ) assert table.n_obs == 21 def test_left_inner_right_exclusive_join(sdata_query_aggregation): - element_dict, table = join_sdata_spatialelement_table( - sdata_query_aggregation, "values_polygons", "table", "right_exclusive" + sdata = sdata_query_aggregation + element_dict, table = join_spatialelement_table( + sdata=sdata, spatial_element_names="values_polygons", table_name="table", how="right_exclusive" + ) + assert table is None + assert all(element_dict[key] is None for key in element_dict) + + element_dict, table = join_spatialelement_table( + spatial_element_names=["values_polygons"], + elements=[sdata["values_polygons"]], + table=sdata["table"], + how="right_exclusive", ) assert table is None assert all(element_dict[key] is None for key in element_dict) - sdata_query_aggregation["values_polygons"] = sdata_query_aggregation["values_polygons"].drop([10, 11]) + sdata["values_polygons"] = sdata["values_polygons"].drop([10, 11]) with pytest.raises(AssertionError, match="No table with"): - join_sdata_spatialelement_table(sdata_query_aggregation, "values_polygons", "not_existing_table", "left") + join_spatialelement_table( + sdata=sdata, spatial_element_names="values_polygons", table_name="not_existing_table", how="left" + ) # Should we reindex before returning the table? - element_dict, table = join_sdata_spatialelement_table(sdata_query_aggregation, "values_polygons", "table", "left") + element_dict, table = join_spatialelement_table( + sdata=sdata, spatial_element_names="values_polygons", table_name="table", how="left" + ) + assert all(element_dict["values_polygons"].index == table.obs["instance_id"].values) + + element_dict, table = join_spatialelement_table( + spatial_element_names=["values_polygons"], elements=[sdata["values_polygons"]], table=sdata["table"], how="left" + ) assert all(element_dict["values_polygons"].index == table.obs["instance_id"].values) # Check no matches in table for element not annotated by table - element_dict, table = join_sdata_spatialelement_table(sdata_query_aggregation, "by_polygons", "table", "left") + element_dict, table = join_spatialelement_table( + sdata=sdata, spatial_element_names="by_polygons", table_name="table", how="left" + ) assert table is None - assert element_dict["by_polygons"] is sdata_query_aggregation["by_polygons"] + assert element_dict["by_polygons"] is sdata["by_polygons"] + + element_dict, table = join_spatialelement_table( + spatial_element_names=["by_polygons"], elements=[sdata["by_polygons"]], table=sdata["table"], how="left" + ) + assert table is None + assert element_dict["by_polygons"] is sdata["by_polygons"] # Check multiple elements, one of which not annotated by table with pytest.warns(UserWarning, match="The element"): - element_dict, table = join_sdata_spatialelement_table( - sdata_query_aggregation, ["by_polygons", "values_polygons"], "table", "left" + element_dict, table = join_spatialelement_table( + sdata=sdata, spatial_element_names=["by_polygons", "values_polygons"], table_name="table", how="left" + ) + assert "by_polygons" in element_dict + + with pytest.warns(UserWarning, match="The element"): + element_dict, table = join_spatialelement_table( + spatial_element_names=["by_polygons", "values_polygons"], + elements=[sdata["by_polygons"], sdata["values_polygons"]], + table=sdata["table"], + how="left", ) assert "by_polygons" in element_dict # check multiple elements joined to table. - sdata_query_aggregation["values_circles"] = sdata_query_aggregation["values_circles"].drop([7, 8]) - element_dict, table = join_sdata_spatialelement_table( - sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "left" + sdata["values_circles"] = sdata["values_circles"].drop([7, 8]) + element_dict, table = join_spatialelement_table( + sdata=sdata, spatial_element_names=["values_circles", "values_polygons"], table_name="table", how="left" + ) + indices = pd.concat( + [element_dict["values_circles"].index.to_series(), element_dict["values_polygons"].index.to_series()] + ) + assert all(table.obs["instance_id"] == indices.values) + + element_dict, table = join_spatialelement_table( + spatial_element_names=["values_circles", "values_polygons"], + elements=[sdata["values_circles"], sdata["values_polygons"]], + table=sdata["table"], + how="left", ) indices = pd.concat( [element_dict["values_circles"].index.to_series(), element_dict["values_polygons"].index.to_series()] @@ -91,8 +147,23 @@ def test_left_inner_right_exclusive_join(sdata_query_aggregation): assert all(table.obs["instance_id"] == indices.values) with pytest.warns(UserWarning, match="The element"): - element_dict, table = join_sdata_spatialelement_table( - sdata_query_aggregation, ["values_circles", "values_polygons", "by_polygons"], "table", "right_exclusive" + element_dict, table = join_spatialelement_table( + sdata=sdata, + spatial_element_names=["values_circles", "values_polygons", "by_polygons"], + table_name="table", + how="right_exclusive", + ) + assert all(element_dict[key] is None for key in element_dict) + assert all(table.obs.index == ["7", "8", "19", "20"]) + assert all(table.obs["instance_id"].values == [7, 8, 10, 11]) + assert all(table.obs["region"].values == ["values_circles", "values_circles", "values_polygons", "values_polygons"]) + + with pytest.warns(UserWarning, match="The element"): + element_dict, table = join_spatialelement_table( + spatial_element_names=["values_circles", "values_polygons", "by_polygons"], + elements=[sdata["values_circles"], sdata["values_polygons"], sdata["by_polygons"]], + table=sdata["table"], + how="right_exclusive", ) assert all(element_dict[key] is None for key in element_dict) assert all(table.obs.index == ["7", "8", "19", "20"]) @@ -101,8 +172,24 @@ def test_left_inner_right_exclusive_join(sdata_query_aggregation): # the triggered warning is: UserWarning: The element `{name}` is not annotated by the table. Skipping with pytest.warns(UserWarning, match="The element"): - element_dict, table = join_sdata_spatialelement_table( - sdata_query_aggregation, ["values_circles", "values_polygons", "by_polygons"], "table", "inner" + element_dict, table = join_spatialelement_table( + sdata=sdata, + spatial_element_names=["values_circles", "values_polygons", "by_polygons"], + table_name="table", + how="inner", + ) + indices = pd.concat( + [element_dict["values_circles"].index.to_series(), element_dict["values_polygons"].index.to_series()] + ) + assert all(table.obs["instance_id"] == indices.values) + assert element_dict["by_polygons"] is None + + with pytest.warns(UserWarning, match="The element"): + element_dict, table = join_spatialelement_table( + spatial_element_names=["values_circles", "values_polygons", "by_polygons"], + elements=[sdata["values_circles"], sdata["values_polygons"], sdata["by_polygons"]], + table=sdata["table"], + how="inner", ) indices = pd.concat( [element_dict["values_circles"].index.to_series(), element_dict["values_polygons"].index.to_series()] @@ -112,86 +199,210 @@ def test_left_inner_right_exclusive_join(sdata_query_aggregation): def test_join_spatialelement_table_fail(full_sdata): - with pytest.warns(UserWarning, match="Images:"): - join_sdata_spatialelement_table(full_sdata, ["image2d", "labels2d"], "table", "left_exclusive") - with pytest.warns(UserWarning, match="Tables:"): - join_sdata_spatialelement_table(full_sdata, ["labels2d", "table"], "table", "left_exclusive") + with pytest.warns(UserWarning, match="Image:"): + join_spatialelement_table( + sdata=full_sdata, spatial_element_names=["image2d", "labels2d"], table_name="table", how="left_exclusive" + ) + with pytest.warns(UserWarning, match="Table:"): + join_spatialelement_table( + sdata=full_sdata, spatial_element_names=["labels2d", "table"], table_name="table", how="left_exclusive" + ) with pytest.raises(TypeError, match="`not_join` is not a"): - join_sdata_spatialelement_table(full_sdata, "labels2d", "table", "not_join") + join_spatialelement_table( + sdata=full_sdata, spatial_element_names="labels2d", table_name="table", how="not_join" + ) def test_left_exclusive_and_right_join(sdata_query_aggregation): + sdata = sdata_query_aggregation # Test case in which all table rows match rows in elements - element_dict, table = join_sdata_spatialelement_table( - sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "left_exclusive" + element_dict, table = join_spatialelement_table( + sdata=sdata, + spatial_element_names=["values_circles", "values_polygons"], + table_name="table", + how="left_exclusive", + ) + assert all(element_dict[key] is None for key in element_dict) + assert table is None + + element_dict, table = join_spatialelement_table( + spatial_element_names=["values_circles", "values_polygons"], + elements=[sdata["values_circles"], sdata["values_polygons"]], + table=sdata["table"], + how="left_exclusive", ) assert all(element_dict[key] is None for key in element_dict) assert table is None # Dropped indices correspond to instance ids 7, 8 for 'values_circles' and 10, 11 for 'values_polygons' - sdata_query_aggregation["table"] = sdata_query_aggregation["table"][ - sdata_query_aggregation["table"].obs.index.drop(["7", "8", "19", "20"]) - ] + sdata["table"] = sdata["table"][sdata["table"].obs.index.drop(["7", "8", "19", "20"])] + with pytest.warns(UserWarning, match="The element"): + element_dict, table = join_spatialelement_table( + sdata=sdata, + spatial_element_names=["values_polygons", "by_polygons"], + table_name="table", + how="left_exclusive", + ) + assert table is None + assert not set(element_dict["values_polygons"].index).issubset(sdata["table"].obs["instance_id"]) + with pytest.warns(UserWarning, match="The element"): - element_dict, table = join_sdata_spatialelement_table( - sdata_query_aggregation, ["values_polygons", "by_polygons"], "table", "left_exclusive" + element_dict, table = join_spatialelement_table( + spatial_element_names=["values_polygons", "by_polygons"], + elements=[sdata["values_polygons"], sdata["by_polygons"]], + table=sdata["table"], + how="left_exclusive", ) assert table is None - assert not set(element_dict["values_polygons"].index).issubset(sdata_query_aggregation["table"].obs["instance_id"]) + assert not set(element_dict["values_polygons"].index).issubset(sdata["table"].obs["instance_id"]) # test right join with pytest.warns(UserWarning, match="The element"): - element_dict, table = join_sdata_spatialelement_table( - sdata_query_aggregation, ["values_circles", "values_polygons", "by_polygons"], "table", "right" + element_dict, table = join_spatialelement_table( + sdata=sdata, + spatial_element_names=["values_circles", "values_polygons", "by_polygons"], + table_name="table", + how="right", ) - assert table is sdata_query_aggregation["table"] + assert table is sdata["table"] assert not {7, 8}.issubset(element_dict["values_circles"].index) assert not {10, 11}.issubset(element_dict["values_polygons"].index) - element_dict, table = join_sdata_spatialelement_table( - sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "left_exclusive" + with pytest.warns(UserWarning, match="The element"): + element_dict, table = join_spatialelement_table( + spatial_element_names=["values_circles", "values_polygons", "by_polygons"], + elements=[sdata["values_circles"], sdata["values_polygons"], sdata["by_polygons"]], + table=sdata["table"], + how="right", + ) + assert table is sdata["table"] + assert not {7, 8}.issubset(element_dict["values_circles"].index) + assert not {10, 11}.issubset(element_dict["values_polygons"].index) + + element_dict, table = join_spatialelement_table( + sdata=sdata, + spatial_element_names=["values_circles", "values_polygons"], + table_name="table", + how="left_exclusive", ) assert table is None assert not np.array_equal( - sdata_query_aggregation["table"].obs.iloc[7:9]["instance_id"].values, + sdata["table"].obs.iloc[7:9]["instance_id"].values, element_dict["values_circles"].index.values, ) assert not np.array_equal( - sdata_query_aggregation["table"].obs.iloc[19:21]["instance_id"].values, + sdata["table"].obs.iloc[19:21]["instance_id"].values, + element_dict["values_polygons"].index.values, + ) + + element_dict, table = join_spatialelement_table( + spatial_element_names=["values_circles", "values_polygons"], + elements=[sdata["values_circles"], sdata["values_polygons"]], + table=sdata["table"], + how="left_exclusive", + ) + assert table is None + assert not np.array_equal( + sdata["table"].obs.iloc[7:9]["instance_id"].values, + element_dict["values_circles"].index.values, + ) + assert not np.array_equal( + sdata["table"].obs.iloc[19:21]["instance_id"].values, element_dict["values_polygons"].index.values, ) def test_match_rows_join(sdata_query_aggregation): + sdata = sdata_query_aggregation reversed_instance_id = [3, 4, 5, 6, 7, 8, 1, 2, 0] + list(reversed(range(12))) - original_instance_id = sdata_query_aggregation.table.obs["instance_id"] - sdata_query_aggregation.table.obs["instance_id"] = reversed_instance_id + original_instance_id = sdata["table"].obs["instance_id"] + sdata["table"].obs["instance_id"] = reversed_instance_id - element_dict, table = join_sdata_spatialelement_table( - sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "left", match_rows="left" + element_dict, table = join_spatialelement_table( + sdata=sdata, + spatial_element_names=["values_circles", "values_polygons"], + table_name="table", + how="left", + match_rows="left", ) assert all(table.obs["instance_id"].values == original_instance_id.values) - element_dict, table = join_sdata_spatialelement_table( - sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "right", match_rows="right" + element_dict, table = join_spatialelement_table( + spatial_element_names=["values_circles", "values_polygons"], + elements=[sdata["values_circles"], sdata["values_polygons"]], + table=sdata["table"], + how="left", + match_rows="left", + ) + assert all(table.obs["instance_id"].values == original_instance_id.values) + + element_dict, table = join_spatialelement_table( + sdata=sdata, + spatial_element_names=["values_circles", "values_polygons"], + table_name="table", + how="right", + match_rows="right", ) indices = [*element_dict["values_circles"].index, *element_dict[("values_polygons")].index] assert all(indices == table.obs["instance_id"]) - element_dict, table = join_sdata_spatialelement_table( - sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "inner", match_rows="left" + element_dict, table = join_spatialelement_table( + spatial_element_names=["values_circles", "values_polygons"], + elements=[sdata["values_circles"], sdata["values_polygons"]], + table=sdata["table"], + how="right", + match_rows="right", + ) + assert all(indices == table.obs["instance_id"]) + + element_dict, table = join_spatialelement_table( + sdata=sdata, + spatial_element_names=["values_circles", "values_polygons"], + table_name="table", + how="inner", + match_rows="left", ) assert all(table.obs["instance_id"].values == original_instance_id.values) - element_dict, table = join_sdata_spatialelement_table( - sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "inner", match_rows="right" + element_dict, table = join_spatialelement_table( + spatial_element_names=["values_circles", "values_polygons"], + elements=[sdata["values_circles"], sdata["values_polygons"]], + table=sdata["table"], + how="inner", + match_rows="left", + ) + assert all(table.obs["instance_id"].values == original_instance_id.values) + + element_dict, table = join_spatialelement_table( + sdata=sdata, + spatial_element_names=["values_circles", "values_polygons"], + table_name="table", + how="inner", + match_rows="right", ) indices = [*element_dict["values_circles"].index, *element_dict[("values_polygons")].index] assert all(indices == table.obs["instance_id"]) + element_dict, table = join_spatialelement_table( + spatial_element_names=["values_circles", "values_polygons"], + elements=[sdata["values_circles"], sdata["values_polygons"]], + table=sdata["table"], + how="inner", + match_rows="right", + ) + assert all(indices == table.obs["instance_id"]) + # check whether table ordering is preserved if not matching - element_dict, table = join_sdata_spatialelement_table( - sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "left" + element_dict, table = join_spatialelement_table( + sdata=sdata, spatial_element_names=["values_circles", "values_polygons"], table_name="table", how="left" + ) + assert all(table.obs["instance_id"] == reversed_instance_id) + + element_dict, table = join_spatialelement_table( + spatial_element_names=["values_circles", "values_polygons"], + elements=[sdata["values_circles"], sdata["values_polygons"]], + table=sdata["table"], + how="left", ) assert all(table.obs["instance_id"] == reversed_instance_id) @@ -324,7 +535,7 @@ def test_get_values_df(sdata_query_aggregation): assert v.shape == (9, 1) # test with multiple values, in the obs - sdata_query_aggregation.table.obs["another_numerical_in_obs"] = v + sdata_query_aggregation["table"].obs["another_numerical_in_obs"] = v v = get_values( value_key=["numerical_in_obs", "another_numerical_in_obs"], sdata=sdata_query_aggregation, @@ -341,14 +552,14 @@ def test_get_values_df(sdata_query_aggregation): # test with multiple values, in the var # prepare the data - adata = sdata_query_aggregation.table + adata = sdata_query_aggregation["table"] X = adata.X new_X = np.hstack([X, X[:, 0:1]]) new_adata = AnnData( X=new_X, obs=adata.obs, var=pd.DataFrame(index=["numerical_in_var", "another_numerical_in_var"]), uns=adata.uns ) - del sdata_query_aggregation.table - sdata_query_aggregation.table = new_adata + del sdata_query_aggregation.tables["table"] + sdata_query_aggregation["table"] = new_adata # test v = get_values( value_key=["numerical_in_var", "another_numerical_in_var"], @@ -360,7 +571,7 @@ def test_get_values_df(sdata_query_aggregation): # test exceptions # value found in multiple locations - sdata_query_aggregation.table.obs["another_numerical_in_gdf"] = np.zeros(21) + sdata_query_aggregation["table"].obs["another_numerical_in_gdf"] = np.zeros(21) with pytest.raises(ValueError): get_values( value_key="another_numerical_in_gdf", @@ -418,35 +629,41 @@ def test_filter_table_categorical_bug(shapes): adata.obs["cell_id"] = np.arange(len(adata)) adata = TableModel.parse(adata, region=["circles"], region_key="region", instance_key="cell_id") adata_subset = adata[adata.obs["categorical"] == "a"].copy() - shapes.table = adata_subset + shapes["table"] = adata_subset shapes.filter_by_coordinate_system("global") def test_labels_table_joins(full_sdata): - element_dict, table = join_sdata_spatialelement_table( - full_sdata, - "labels2d", - "table", - "left", + element_dict, table = join_spatialelement_table( + sdata=full_sdata, + spatial_element_names="labels2d", + table_name="table", + how="left", ) assert all(table.obs["instance_id"] == range(100)) full_sdata["table"].obs["instance_id"] = list(reversed(range(100))) - element_dict, table = join_sdata_spatialelement_table(full_sdata, "labels2d", "table", "left", match_rows="left") + element_dict, table = join_spatialelement_table( + sdata=full_sdata, spatial_element_names="labels2d", table_name="table", how="left", match_rows="left" + ) assert all(table.obs["instance_id"] == range(100)) with pytest.warns(UserWarning, match="Element type"): - join_sdata_spatialelement_table(full_sdata, "labels2d", "table", "left_exclusive") + join_spatialelement_table( + sdata=full_sdata, spatial_element_names="labels2d", table_name="table", how="left_exclusive" + ) with pytest.warns(UserWarning, match="Element type"): - join_sdata_spatialelement_table(full_sdata, "labels2d", "table", "inner") + join_spatialelement_table(sdata=full_sdata, spatial_element_names="labels2d", table_name="table", how="inner") with pytest.warns(UserWarning, match="Element type"): - join_sdata_spatialelement_table(full_sdata, "labels2d", "table", "right") + join_spatialelement_table(sdata=full_sdata, spatial_element_names="labels2d", table_name="table", how="right") # all labels are present in table so should return None - element_dict, table = join_sdata_spatialelement_table(full_sdata, "labels2d", "table", "right_exclusive") + element_dict, table = join_spatialelement_table( + sdata=full_sdata, spatial_element_names="labels2d", table_name="table", how="right_exclusive" + ) assert element_dict["labels2d"] is None assert table is None @@ -455,7 +672,9 @@ def test_points_table_joins(full_sdata): full_sdata["table"].uns["spatialdata_attrs"]["region"] = "points_0" full_sdata["table"].obs["region"] = ["points_0"] * 100 - element_dict, table = join_sdata_spatialelement_table(full_sdata, "points_0", "table", "left") + element_dict, table = join_spatialelement_table( + sdata=full_sdata, spatial_element_names="points_0", table_name="table", how="left" + ) # points should have the same number of rows as before and table as well assert len(element_dict["points_0"]) == 300 @@ -463,28 +682,40 @@ def test_points_table_joins(full_sdata): full_sdata["table"].obs["instance_id"] = list(reversed(range(100))) - element_dict, table = join_sdata_spatialelement_table(full_sdata, "points_0", "table", "left", match_rows="left") + element_dict, table = join_spatialelement_table( + sdata=full_sdata, spatial_element_names="points_0", table_name="table", how="left", match_rows="left" + ) assert len(element_dict["points_0"]) == 300 assert all(table.obs["instance_id"] == range(100)) # We have 100 table instances so resulting length of points should be 200 as we started with 300 - element_dict, table = join_sdata_spatialelement_table(full_sdata, "points_0", "table", "left_exclusive") + element_dict, table = join_spatialelement_table( + sdata=full_sdata, spatial_element_names="points_0", table_name="table", how="left_exclusive" + ) assert len(element_dict["points_0"]) == 200 assert table is None - element_dict, table = join_sdata_spatialelement_table(full_sdata, "points_0", "table", "inner") + element_dict, table = join_spatialelement_table( + sdata=full_sdata, spatial_element_names="points_0", table_name="table", how="inner" + ) assert len(element_dict["points_0"]) == 100 assert all(table.obs["instance_id"] == list(reversed(range(100)))) - element_dict, table = join_sdata_spatialelement_table(full_sdata, "points_0", "table", "right") + element_dict, table = join_spatialelement_table( + sdata=full_sdata, spatial_element_names="points_0", table_name="table", how="right" + ) assert len(element_dict["points_0"]) == 100 assert all(table.obs["instance_id"] == list(reversed(range(100)))) - element_dict, table = join_sdata_spatialelement_table(full_sdata, "points_0", "table", "right", match_rows="right") + element_dict, table = join_spatialelement_table( + sdata=full_sdata, spatial_element_names="points_0", table_name="table", how="right", match_rows="right" + ) assert all(element_dict["points_0"].index.values.compute() == list(reversed(range(100)))) assert all(table.obs["instance_id"] == list(reversed(range(100)))) - element_dict, table = join_sdata_spatialelement_table(full_sdata, "points_0", "table", "right_exclusive") + element_dict, table = join_spatialelement_table( + sdata=full_sdata, spatial_element_names="points_0", table_name="table", how="right_exclusive" + ) assert element_dict["points_0"] is None assert table is None diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index a043cab0..3152af91 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -409,15 +409,15 @@ def test_query_filter_table(with_polygon_query: bool): target_coordinate_system="global", ) - assert len(queried0.table) == 1 - assert len(queried1.table) == 3 + assert len(queried0["table"]) == 1 + assert len(queried1["table"]) == 3 def test_polygon_query_with_multipolygon(sdata_query_aggregation): sdata = sdata_query_aggregation values_sdata = SpatialData( shapes={"values_polygons": sdata["values_polygons"], "values_circles": sdata["values_circles"]}, - tables=sdata.table, + tables=sdata["table"], ) polygon = sdata["by_polygons"].geometry.iloc[0] circle = sdata["by_circles"].geometry.iloc[0] @@ -432,19 +432,19 @@ def test_polygon_query_with_multipolygon(sdata_query_aggregation): ) assert len(queried["values_polygons"]) == 4 assert len(queried["values_circles"]) == 4 - assert len(queried.table) == 8 + assert len(queried["table"]) == 8 multipolygon = GeoDataFrame(geometry=[polygon, circle_pol]).unary_union queried = polygon_query(values_sdata, polygon=multipolygon, target_coordinate_system="global") assert len(queried["values_polygons"]) == 8 assert len(queried["values_circles"]) == 8 - assert len(queried.table) == 16 + assert len(queried["table"]) == 16 multipolygon = GeoDataFrame(geometry=[polygon, polygon]).unary_union queried = polygon_query(values_sdata, polygon=multipolygon, target_coordinate_system="global") assert len(queried["values_polygons"]) == 4 assert len(queried["values_circles"]) == 4 - assert len(queried.table) == 8 + assert len(queried["table"]) == 8 PLOT = False if PLOT: diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 5979f637..e22182b4 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -5,7 +5,7 @@ def test_datasets() -> None: extra_cs = "test" sdata_blobs = blobs(extra_coord_system=extra_cs) - assert len(sdata_blobs.table) == 26 + assert len(sdata_blobs["table"]) == 26 assert len(sdata_blobs.shapes["blobs_circles"]) == 5 assert len(sdata_blobs.shapes["blobs_polygons"]) == 5 assert len(sdata_blobs.shapes["blobs_multipolygons"]) == 2 @@ -19,7 +19,7 @@ def test_datasets() -> None: _ = str(sdata_blobs) sdata_raccoon = raccoon() - assert sdata_raccoon.table is None + assert "table" not in sdata_raccoon.tables assert len(sdata_raccoon.shapes["circles"]) == 4 assert sdata_raccoon.images["raccoon"].shape == (3, 768, 1024) assert sdata_raccoon.labels["segmentation"].shape == (768, 1024) diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 81d43864..7a0b51af 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -73,9 +73,9 @@ def _test_table(self, tmp_path: str, table: SpatialData) -> None: tmpdir = Path(tmp_path) / "tmp.zarr" table.write(tmpdir) sdata = SpatialData.read(tmpdir) - pd.testing.assert_frame_equal(table.table.obs, sdata.table.obs) + pd.testing.assert_frame_equal(table["table"].obs, sdata["table"].obs) try: - assert table.table.uns == sdata.table.uns + assert table["table"].uns == sdata["table"].uns except ValueError as e: raise e @@ -305,20 +305,20 @@ def test_io_table(shapes): adata.obs["instance"] = shapes.shapes["circles"].index adata = TableModel().parse(adata, region="circles", region_key="region", instance_key="instance") shapes.table = adata - del shapes.table + del shapes.tables["table"] shapes.table = adata with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") shapes.write(f) shapes2 = SpatialData.read(f) - assert shapes2.table is not None - assert shapes2.table.shape == (5, 10) + assert "table" in shapes2.tables + assert shapes2["table"].shape == (5, 10) - del shapes2.table - assert shapes2.table is None + del shapes2.tables["table"] + assert "table" not in shapes2.tables shapes2.table = adata - assert shapes2.table is not None - assert shapes2.table.shape == (5, 10) + assert "table" in shapes2.tables + assert shapes2["table"].shape == (5, 10) def test_bug_rechunking_after_queried_raster():