Skip to content

Commit

Permalink
Test joins with string indices and instance id (#485)
Browse files Browse the repository at this point in the history
* test join strings

* fix dtype aggregate

---------

Co-authored-by: Luca Marconato <[email protected]>
  • Loading branch information
melonora and LucaMarconato authored Mar 14, 2024
1 parent 09e339e commit a2970d3
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 29 deletions.
9 changes: 8 additions & 1 deletion src/spatialdata/_core/operations/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,14 @@ def _create_sdata_from_table_and_shapes(
) -> SpatialData:
from spatialdata._core._deepcopy import deepcopy as _deepcopy

table.obs[instance_key] = table.obs_names.copy()
shapes_index_dtype = shapes.index.dtype if isinstance(shapes, GeoDataFrame) else shapes.dtype
try:
table.obs[instance_key] = table.obs_names.copy().astype(shapes_index_dtype)
except ValueError as err:
raise TypeError(
f"Instance key column dtype in table resulting from aggregation cannot be cast to the dtype of"
f"element {shapes_name}.index"
) from err
table.obs[region_key] = shapes_name
table = TableModel.parse(table, region=shapes_name, region_key=region_key, instance_key=instance_key)

Expand Down
7 changes: 7 additions & 0 deletions src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,13 @@ def validate_table_in_spatialdata(self, table: AnnData) -> None:
else:
dtype = element.index.dtype
if dtype != table.obs[instance_key].dtype:
if dtype == str or table.obs[instance_key].dtype == str:
raise TypeError(
f"Table instance_key column ({instance_key}) has a dtype "
f"({table.obs[instance_key].dtype}) that does not match the dtype of the indices of "
f"the annotated element ({dtype})."
)

warnings.warn(
(
f"Table instance_key column ({instance_key}) has a dtype "
Expand Down
15 changes: 6 additions & 9 deletions src/spatialdata/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage
from multiscale_spatial_image.to_multiscale.to_multiscale import Methods
from pandas import CategoricalDtype
from pandas.errors import IntCastingNaNError
from shapely._geometry import GeometryType
from shapely.geometry import MultiPolygon, Point, Polygon
from shapely.geometry.collection import GeometryCollection
Expand Down Expand Up @@ -795,6 +794,11 @@ def _validate_table_annotation_metadata(self, data: AnnData) -> None:
raise ValueError(f"`{attr[self.REGION_KEY_KEY]}` not found in `adata.obs`.")
if attr[self.INSTANCE_KEY] not in data.obs:
raise ValueError(f"`{attr[self.INSTANCE_KEY]}` not found in `adata.obs`.")
if (dtype := data.obs[attr[self.INSTANCE_KEY]].dtype) not in [np.int16, np.int32, np.int64, str]:
raise TypeError(
f"Only np.int16, np.int32, np.int64 or string allowed as dtype for "
f"instance_key column in obs. Dtype found to be {dtype}"
)
expected_regions = attr[self.REGION_KEY] if isinstance(attr[self.REGION_KEY], list) else [attr[self.REGION_KEY]]
found_regions = data.obs[attr[self.REGION_KEY_KEY]].unique().tolist()
if len(set(expected_regions).symmetric_difference(set(found_regions))) > 0:
Expand Down Expand Up @@ -881,14 +885,6 @@ def parse(
adata.obs[region_key] = pd.Categorical(adata.obs[region_key])
if instance_key is None:
raise ValueError("`instance_key` must be provided.")
if adata.obs[instance_key].dtype != int:
try:
warnings.warn(
f"Converting `{cls.INSTANCE_KEY}: {instance_key}` to integer dtype.", UserWarning, stacklevel=2
)
adata.obs[instance_key] = adata.obs[instance_key].astype(int)
except IntCastingNaNError as exc:
raise ValueError("Values within table.obs[] must be able to be coerced to int dtype.") from exc

grouped = adata.obs.groupby(region_key, observed=True)
grouped_size = grouped.size()
Expand All @@ -901,6 +897,7 @@ def parse(

attr = {"region": region, "region_key": region_key, "instance_key": instance_key}
adata.uns[cls.ATTRS_KEY] = attr
cls().validate(adata)
return adata


Expand Down
11 changes: 2 additions & 9 deletions tests/core/operations/test_spatialdata_operations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import math
import warnings

import numpy as np
import pytest
Expand Down Expand Up @@ -419,10 +418,7 @@ def test_validate_table_in_spatialdata(full_sdata):
region, region_key, _ = get_table_keys(table)
assert region == "labels2d"

# no warnings
with warnings.catch_warnings():
warnings.simplefilter("error")
full_sdata.validate_table_in_spatialdata(table)
full_sdata.validate_table_in_spatialdata(table)

# dtype mismatch
full_sdata.labels["labels2d"] = Labels2DModel.parse(full_sdata.labels["labels2d"].astype("int16"))
Expand All @@ -437,10 +433,7 @@ def test_validate_table_in_spatialdata(full_sdata):
table.obs[region_key] = "points_0"
full_sdata.set_table_annotates_spatialelement("table", region="points_0")

# no warnings
with warnings.catch_warnings():
warnings.simplefilter("error")
full_sdata.validate_table_in_spatialdata(table)
full_sdata.validate_table_in_spatialdata(table)

# dtype mismatch
full_sdata.points["points_0"].index = full_sdata.points["points_0"].index.astype("int16")
Expand Down
31 changes: 31 additions & 0 deletions tests/core/query/test_relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,37 @@ def test_match_table_to_element(sdata_query_aggregation):
# TODO: add tests for labels


def test_join_using_string_instance_id_and_index(sdata_query_aggregation):
sdata_query_aggregation["table"].obs["instance_id"] = [
f"string_{i}" for i in sdata_query_aggregation["table"].obs["instance_id"]
]
sdata_query_aggregation["values_circles"].index = pd.Index(
[f"string_{i}" for i in sdata_query_aggregation["values_circles"].index]
)
sdata_query_aggregation["values_polygons"].index = pd.Index(
[f"string_{i}" for i in sdata_query_aggregation["values_polygons"].index]
)

sdata_query_aggregation["values_polygons"] = sdata_query_aggregation["values_polygons"][:5]
sdata_query_aggregation["values_circles"] = sdata_query_aggregation["values_circles"][:5]

element_dict, table = join_sdata_spatialelement_table(
sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "inner"
)
# Note that we started with 21 n_obs.
assert table.n_obs == 10

element_dict, table = join_sdata_spatialelement_table(
sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "right_exclusive"
)
assert table.n_obs == 11

element_dict, table = join_sdata_spatialelement_table(
sdata_query_aggregation, ["values_circles", "values_polygons"], "table", "right"
)
assert table.n_obs == 21


def test_left_inner_right_exclusive_join(sdata_query_aggregation):
element_dict, table = join_sdata_spatialelement_table(
sdata_query_aggregation, "values_polygons", "table", "right_exclusive"
Expand Down
18 changes: 8 additions & 10 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,14 @@ def test_table_model(
region: str | np.ndarray,
) -> None:
region_key = "reg"
obs = pd.DataFrame(
RNG.choice(np.arange(0, 100, dtype=float), size=(10, 3), replace=False), columns=["A", "B", "C"]
)
obs[region_key] = region
adata = AnnData(RNG.normal(size=(10, 2)), obs=obs)
with pytest.raises(TypeError, match="Only np.int16"):
model.parse(adata, region=region, region_key=region_key, instance_key="A")

obs = pd.DataFrame(RNG.choice(np.arange(0, 100), size=(10, 3), replace=False), columns=["A", "B", "C"])
obs[region_key] = region
adata = AnnData(RNG.normal(size=(10, 2)), obs=obs)
Expand All @@ -332,16 +340,6 @@ def test_table_model(
assert TableModel.REGION_KEY_KEY in table.uns[TableModel.ATTRS_KEY]
assert table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] == region

obs["A"] = obs["A"].astype(str)
adata = AnnData(RNG.normal(size=(10, 2)), obs=obs)
with pytest.warns(UserWarning, match="Converting"):
model.parse(adata, region=region, region_key=region_key, instance_key="A")

obs["A"] = pd.Series(len([chr(ord("a") + i) for i in range(10)]))
adata = AnnData(RNG.normal(size=(10, 2)), obs=obs)
with pytest.raises(ValueError, match="Values within"):
model.parse(adata, region=region, region_key=region_key, instance_key="A")

@pytest.mark.parametrize("model", [TableModel])
@pytest.mark.parametrize("region", [["sample_1"] * 5 + ["sample_2"] * 5])
def test_table_instance_key_values_not_unique(self, model: TableModel, region: str | np.ndarray):
Expand Down

0 comments on commit a2970d3

Please sign in to comment.