Skip to content

Commit

Permalink
Join outside sdata (scverse#512)
Browse files Browse the repository at this point in the history
* add join outside sdata

* add join outside sdata

* add tests

* add tests

* silence warnings

* Fix all tests

* change to new function name

* adjust docs
  • Loading branch information
melonora authored Mar 26, 2024
1 parent 16162a3 commit 2c6ae3f
Show file tree
Hide file tree
Showing 11 changed files with 468 additions and 156 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ and this project adheres to [Semantic Versioning][].

- Implemented support in SpatialData for storing multiple tables. These tables can annotate a SpatialElement but not
necessarily so.
- Added SQL like joins that can be executed by calling one public function `join_sdata_spatialelement_table`. The
- Added SQL like joins that can be executed by calling one public function `join_spatialelement_table`. The
following joins are supported: `left`, `left_exclusive`, `right`, `right_exclusive` and `inner`. The function has
an option to match rows. For `left` only matching `left` is supported and for `right` join only `right` matching of
rows is supported. Not all joins are supported for `Labels` elements.
rows is supported. Not all joins are supported for `Labels` elements. The elements and table can either exist within
a `SpatialData` object or outside.
- Added function `match_element_to_table` which allows the user to perform a right join of `SpatialElement`(s) with a
table with rows matching the row order in the table.
- Increased in-memory vs on-disk control: changes performed in-memory (e.g. adding a new image) are not automatically
Expand Down
2 changes: 1 addition & 1 deletion docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Operations on `SpatialData` objects.
get_values
get_extent
get_centroids
join_sdata_spatialelement_table
join_spatialelement_table
match_element_to_table
get_centroids
match_table_to_element
Expand Down
4 changes: 2 additions & 2 deletions src/spatialdata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"bounding_box_query",
"polygon_query",
"get_values",
"join_sdata_spatialelement_table",
"join_spatialelement_table",
"match_element_to_table",
"match_table_to_element",
"SpatialData",
Expand All @@ -49,7 +49,7 @@
from spatialdata._core.query._utils import circles_to_polygons, get_bounding_box_corners
from spatialdata._core.query.relational_query import (
get_values,
join_sdata_spatialelement_table,
join_spatialelement_table,
match_element_to_table,
match_table_to_element,
)
Expand Down
146 changes: 111 additions & 35 deletions src/spatialdata/_core/query/relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from spatialdata._core.spatialdata import SpatialData
from spatialdata._utils import _inplace_fix_subset_categorical_obs
from spatialdata.models import (
Image2DModel,
Image3DModel,
Labels2DModel,
Labels3DModel,
PointsModel,
Expand Down Expand Up @@ -462,14 +464,78 @@ class MatchTypes(Enum):
no = "no"


def join_sdata_spatialelement_table(
sdata: SpatialData,
spatial_element_name: str | list[str],
table_name: str,
how: str = "left",
def _create_sdata_elements_dict_for_join(
sdata: SpatialData, spatial_element_name: str | list[str], table_name: str
) -> tuple[dict[str, dict[str, Any]], AnnData]:
assert sdata.tables.get(table_name), f"No table with `{table_name}` exists in the SpatialData object."
table = sdata.tables[table_name]

elements_dict: dict[str, dict[str, Any]] = defaultdict(lambda: defaultdict(dict))
for name in spatial_element_name:
if name in sdata.tables:
warnings.warn(
f"Table: `{name}` given in spatial_element_names cannot be "
f"joined with a table using this function.",
UserWarning,
stacklevel=2,
)
elif name in sdata.images:
warnings.warn(
f"Image: `{name}` cannot be joined with a table",
UserWarning,
stacklevel=2,
)
else:
element_type, _, element = sdata._find_element(name)
elements_dict[element_type][name] = element
return elements_dict, table


def _create_elements_dict_for_join(
spatial_element_name: str | list[str], elements: SpatialElement | list[SpatialElement], table: AnnData
) -> dict[str, dict[str, Any]]:
elements = elements if isinstance(elements, list) else [elements]

elements_dict: dict[str, dict[str, Any]] = defaultdict(lambda: defaultdict(dict))
for name, element in zip(spatial_element_name, elements):
model = get_model(element)

if model == TableModel:
warnings.warn(
f"Table: `{name}` given in spatial_element_name cannot be " f"joined with a table using this function.",
UserWarning,
stacklevel=2,
)
continue
if model in [Image2DModel, Image3DModel]:
warnings.warn(
f"Image: `{name}` cannot be joined with a table",
UserWarning,
stacklevel=2,
)
continue

if model in [Labels2DModel, Labels3DModel]:
element_type = "labels"
elif model == PointsModel:
element_type = "points"
elif model == ShapesModel:
element_type = "shapes"
elements_dict[element_type][name] = element
return elements_dict


def join_spatialelement_table(
spatial_element_names: str | list[str],
elements: SpatialElement | list[SpatialElement] | None = None,
table: AnnData | None = None,
table_name: str | None = None,
sdata: SpatialData | None = None,
how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "left",
match_rows: Literal["no", "left", "right"] = "no",
) -> tuple[dict[str, Any], AnnData]:
"""Join SpatialElement(s) and table together in SQL like manner.
"""
Join SpatialElement(s) and table together in SQL like manner.
The function allows the user to perform SQL like joins of SpatialElements and a table. The elements are not
returned together in one dataframe like structure, but instead filtered elements are returned. To determine matches,
Expand All @@ -489,10 +555,12 @@ def join_sdata_spatialelement_table(
Parameters
----------
sdata
The SpatialData object containing the tables and spatial elements.
spatial_element_name
The name(s) of the spatial elements to be joined with the table.
spatial_element_names
The name(s) of the spatial elements to be joined with the table. If a list of names the indices must match
with the list of SpatialElements passed on by the argument elements.
elements
The SpatialElement(s) to be joined with the table. In case of a list of SpatialElements the indices
must match exactly with the indices in the list of spatial_element_name.
table_name
The name of the table to join with the spatial elements.
how
Expand All @@ -508,35 +576,41 @@ def join_sdata_spatialelement_table(
Raises
------
AssertionError
If no table with the given table_name exists in the SpatialData object.
ValueError
If table_name is provided but not present in the SpatialData object.
ValueError
If no valid elements are provided for the join operation.
ValueError
If the provided join type is not supported.
ValueError
If an incorrect value is given for match_rows.
"""
assert sdata.tables.get(table_name), f"No table with `{table_name}` exists in the SpatialData object."
table = sdata.tables[table_name]
if isinstance(spatial_element_name, str):
spatial_element_name = [spatial_element_name]

elements_dict: dict[str, dict[str, Any]] = defaultdict(lambda: defaultdict(dict))
for name in spatial_element_name:
if name in sdata.tables:
warnings.warn(
f"Tables: `{', '.join(elements_dict['tables'].keys())}` given in spatial_element_names cannot be "
f"joined with a table using this function.",
UserWarning,
stacklevel=2,
)
elif name in sdata.images:
warnings.warn(
f"Images: `{', '.join(elements_dict['images'].keys())}` cannot be joined with a table",
UserWarning,
stacklevel=2,
)
spatial_element_names = (
spatial_element_names if isinstance(spatial_element_names, list) else [spatial_element_names]
)
sdata_args = [sdata, table_name]
non_sdata_args = [elements, table]
if any(arg is not None for arg in sdata_args):
assert all(
arg is None for arg in non_sdata_args
), "If `sdata` and `table_name` are specified, `elements` and `table` should not be specified."
if sdata is not None and table_name is not None:
elements_dict, table = _create_sdata_elements_dict_for_join(sdata, spatial_element_names, table_name)
else:
element_type, _, element = sdata._find_element(name)
elements_dict[element_type][name] = element
raise ValueError("If either `sdata` or `table_name` is specified, both should be specified.")
if any(arg is not None for arg in non_sdata_args):
assert all(
arg is not None for arg in non_sdata_args
), "both `elements` and `table` must be given if either is specified."
elements_dict = _create_elements_dict_for_join(spatial_element_names, elements, table_name)

elements_dict, table = _call_join(elements_dict, table, how, match_rows)
return elements_dict, table


def _call_join(
elements_dict: dict[str, dict[str, Any]], table: AnnData, how: str, match_rows: Literal["no", "left", "right"]
) -> tuple[dict[str, Any], AnnData]:
assert any(key in elements_dict for key in ["labels", "shapes", "points"]), (
"No valid element to join in spatial_element_name. Must provide at least one of either `labels`, `points` or "
"`shapes`."
Expand Down Expand Up @@ -611,7 +685,9 @@ def match_element_to_table(
-------
A tuple containing the joined elements as a dictionary and the joined table as an AnnData object.
"""
element_dict, table = join_sdata_spatialelement_table(sdata, element_name, table_name, "right", match_rows="right")
element_dict, table = join_spatialelement_table(
sdata=sdata, spatial_element_names=element_name, table_name=table_name, how="right", match_rows="right"
)
return element_dict, table


Expand Down
10 changes: 7 additions & 3 deletions src/spatialdata/dataloader/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from spatialdata._core.centroids import get_centroids
from spatialdata._core.operations.transform import transform
from spatialdata._core.operations.vectorize import to_circles
from spatialdata._core.query.relational_query import _get_unique_label_values_as_index, join_sdata_spatialelement_table
from spatialdata._core.query.relational_query import _get_unique_label_values_as_index, join_spatialelement_table
from spatialdata._core.spatialdata import SpatialData
from spatialdata.models import (
Image2DModel,
Expand Down Expand Up @@ -261,8 +261,12 @@ def _preprocess(
if table_name is not None:
table_subset = filtered_table[filtered_table.obs[region_key] == region_name]
circles_sdata = SpatialData.init_from_elements({region_name: circles}, tables=table_subset.copy())
_, table = join_sdata_spatialelement_table(
circles_sdata, region_name, table_name, how="left", match_rows="left"
_, table = join_spatialelement_table(
sdata=circles_sdata,
spatial_element_names=region_name,
table_name=table_name,
how="left",
match_rows="left",
)
# get index dictionary, with `instance_id`, `cs`, `region`, and `image`
tables_l.append(table)
Expand Down
34 changes: 17 additions & 17 deletions tests/core/operations/test_spatialdata_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,10 @@ def test_filter_by_coordinate_system_also_table(full_sdata: SpatialData) -> None
from spatialdata.models import TableModel

rng = np.random.default_rng(seed=0)
full_sdata.table.obs["annotated_shapes"] = rng.choice(["circles", "poly"], size=full_sdata.table.shape[0])
adata = full_sdata.table
full_sdata["table"].obs["annotated_shapes"] = rng.choice(["circles", "poly"], size=full_sdata["table"].shape[0])
adata = full_sdata["table"]
del adata.uns[TableModel.ATTRS_KEY]
del full_sdata.table
del full_sdata.tables["table"]
full_sdata.table = TableModel.parse(
adata, region=["circles", "poly"], region_key="annotated_shapes", instance_key="instance_id"
)
Expand All @@ -145,8 +145,8 @@ def test_filter_by_coordinate_system_also_table(full_sdata: SpatialData) -> None
filtered_sdata1 = full_sdata.filter_by_coordinate_system(coordinate_system="my_space1")
filtered_sdata2 = full_sdata.filter_by_coordinate_system(coordinate_system="my_space0", filter_table=False)

assert len(filtered_sdata0.table) + len(filtered_sdata1.table) == len(full_sdata.table)
assert len(filtered_sdata2.table) == len(full_sdata.table)
assert len(filtered_sdata0["table"]) + len(filtered_sdata1["table"]) == len(full_sdata["table"])
assert len(filtered_sdata2["table"]) == len(full_sdata["table"])


def test_rename_coordinate_systems(full_sdata: SpatialData) -> None:
Expand Down Expand Up @@ -257,7 +257,7 @@ def test_concatenate_sdatas(full_sdata: SpatialData) -> None:
with pytest.raises(KeyError):
concatenate([full_sdata, SpatialData(shapes={"circles": full_sdata.shapes["circles"]})])

assert concatenate([full_sdata, SpatialData()]).table is not None
assert concatenate([full_sdata, SpatialData()])["table"] is not None

set_transformation(full_sdata.shapes["circles"], Identity(), "my_space0")
set_transformation(full_sdata.shapes["poly"], Identity(), "my_space1")
Expand All @@ -268,11 +268,11 @@ def test_concatenate_sdatas(full_sdata: SpatialData) -> None:
# this is needed cause we can't handle regions with same name.
# TODO: fix this
new_region = "sample2"
table_new = filtered1.table.copy()
del filtered1.table
filtered1.table = table_new
filtered1.table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = new_region
filtered1.table.obs[filtered1.table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY]] = new_region
table_new = filtered1["table"].copy()
del filtered1.tables["table"]
filtered1["table"] = table_new
filtered1["table"].uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = new_region
filtered1["table"].obs[filtered1["table"].uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY]] = new_region
concatenated = concatenate([filtered0, filtered1], concatenate_tables=True)
assert len(list(concatenated.gen_elements())) == 3

Expand Down Expand Up @@ -341,20 +341,20 @@ def test_subset(full_sdata: SpatialData) -> None:
assert "image3d_xarray" in full_sdata.images
assert unique_names == set(element_names)
# no table since the labels are not present in the subset
assert subset0.table is None
assert "table" not in subset0.tables

adata = AnnData(
shape=(10, 0),
obs={"region": ["circles"] * 5 + ["poly"] * 5, "instance_id": [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]},
)
del full_sdata.table
del full_sdata.tables["table"]
sdata_table = TableModel.parse(adata, region=["circles", "poly"], region_key="region", instance_key="instance_id")
full_sdata.table = sdata_table
full_sdata["table"] = sdata_table
full_sdata.tables["second_table"] = sdata_table
subset1 = full_sdata.subset(["poly", "second_table"])
assert subset1.table is not None
assert len(subset1.table) == 5
assert subset1.table.obs["region"].unique().tolist() == ["poly"]
assert subset1["table"] is not None
assert len(subset1["table"]) == 5
assert subset1["table"].obs["region"].unique().tolist() == ["poly"]
assert len(subset1["second_table"]) == 10


Expand Down
2 changes: 1 addition & 1 deletion tests/core/operations/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def test_transform_elements_and_entire_spatial_data_object_multi_hop(
labels=dict(full_sdata.labels),
points=dict(full_sdata.points),
shapes=dict(full_sdata.shapes),
table=full_sdata.table,
table=full_sdata["table"],
)
temp["transformed_element"] = transformed_element
transformation = get_transformation_between_coordinate_systems(
Expand Down
Loading

0 comments on commit 2c6ae3f

Please sign in to comment.