[Backport release-1.15][python] Rename to_spatial_data -> `to_spatialdata` (#3451)

* merge

* fix merge

* fix bad _query.py merge; rebase

---------

Co-authored-by: Julia Dark <[email protected]>
johnkerl and jp-dark authored Dec 17, 2024
1 parent 41dbfff commit 59ed899
Showing 8 changed files with 475 additions and 42 deletions.
102 changes: 97 additions & 5 deletions apis/python/src/tiledbsoma/_query.py
@@ -6,6 +6,7 @@
"""Implementation of a SOMA Experiment.
"""
import enum
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import (
TYPE_CHECKING,
@@ -54,6 +55,7 @@

if TYPE_CHECKING:
from ._experiment import Experiment
from ._constants import SPATIAL_DISCLAIMER
from ._fastercsx import CompressedMatrix
from ._measurement import Measurement
from ._sparse_nd_array import SparseNDArray
@@ -407,9 +409,14 @@ def obs_scene_ids(self) -> pa.Array:
try:
obs_scene = self.experiment.obs_spatial_presence
except KeyError as ke:
raise KeyError("Missing obs_scene") from ke
raise KeyError(
"No obs_spatial_presence dataframe in this experiment."
) from ke
if not isinstance(obs_scene, DataFrame):
raise TypeError("obs_scene must be a dataframe.")
raise TypeError(
f"obs_spatial_presence must be a dataframe; got "
f"{type(obs_scene).__name__}."
)

full_table = obs_scene.read(
coords=((Axis.OBS.getattr_from(self._joinids), slice(None))),
@@ -428,12 +435,18 @@ def var_scene_ids(self) -> pa.Array:
try:
var_scene = self._ms.var_spatial_presence
except KeyError as ke:
raise KeyError("Missing var_scene") from ke
raise KeyError(
f"No var_spatial_presence dataframe in measurement "
f"'{self.measurement_name}'."
) from ke
if not isinstance(var_scene, DataFrame):
raise TypeError("var_scene must be a dataframe.")
raise TypeError(
f"var_spatial_presence must be a dataframe; got "
f"{type(var_scene).__name__}."
)

full_table = var_scene.read(
coords=((Axis.OBS.getattr_from(self._joinids), slice(None))),
coords=((Axis.VAR.getattr_from(self._joinids), slice(None))),
result_order=ResultOrder.COLUMN_MAJOR,
value_filter="data != 0",
).concat()
@@ -473,6 +486,85 @@ def to_anndata(

return ad

def to_spatialdata( # type: ignore[no-untyped-def]
self,
X_name: str,
*,
column_names: Optional[AxisColumnNames] = None,
X_layers: Sequence[str] = (),
obsm_layers: Sequence[str] = (),
obsp_layers: Sequence[str] = (),
varm_layers: Sequence[str] = (),
varp_layers: Sequence[str] = (),
drop_levels: bool = False,
scene_presence_mode: str = "obs",
):
"""Returns a SpatialData object containing the query results
This is a low-level routine intended to be used by loaders for other
in-core formats, such as AnnData, which can be created from the
resulting objects.
Args:
X_name: The X layer to read and return in the ``X`` slot.
column_names: The columns in the ``var`` and ``obs`` dataframes
to read.
X_layers: Additional X layers to read and return in the ``layers`` slot.
obsm_layers: Additional obsm layers to read and return in the obsm slot.
obsp_layers: Additional obsp layers to read and return in the obsp slot.
varm_layers: Additional varm layers to read and return in the varm slot.
varp_layers: Additional varp layers to read and return in the varp slot.
drop_levels: If ``True`` remove unused categories from the AnnData table
with the measurement data.
scene_presence_mode: Method for determining what scenes to return data
from. Valid options are ``obs`` (use ``obs_spatial_presence``
dataframe) and ``var`` (use ``var_spatial_presence`` dataframe).
Defaults to ``obs``.
"""

from spatialdata import SpatialData

from .io.spatial.outgest import _add_scene_to_spatialdata

warnings.warn(SPATIAL_DISCLAIMER)

# Get a list of scenes to add to SpatialData object.
if scene_presence_mode == "obs":
scene_ids = self.obs_scene_ids()
elif scene_presence_mode == "var":
scene_ids = self.var_scene_ids()
else:
raise ValueError(
f"Invalid scene presence mode '{scene_presence_mode}'. Valid options "
f"are 'obs' and 'var'."
)

# Get the anndata table.
ad = self.to_anndata(
X_name,
column_names=column_names,
X_layers=X_layers,
obsm_layers=obsm_layers,
obsp_layers=obsp_layers,
varm_layers=varm_layers,
varp_layers=varp_layers,
drop_levels=drop_levels,
)
sdata = SpatialData(tables={self.measurement_name: ad})

for scene_id in scene_ids:
scene = self.experiment.spatial[str(scene_id)]
_add_scene_to_spatialdata(
sdata,
scene_id=str(scene_id),
scene=scene,
obs_id_name="soma_joinid",
var_id_name="soma_joinid",
measurement_names=(self.measurement_name,),
)

return sdata

# Context management

def __enter__(self) -> Self:
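Below is a minimal usage sketch of the new `ExperimentAxisQuery.to_spatialdata` method added in this file; the experiment path, measurement name, X layer name, and obs value filter are illustrative placeholders, not part of this commit.

# Hypothetical usage of ExperimentAxisQuery.to_spatialdata; the path,
# measurement name, X layer, and obs filter below are placeholders.
import tiledbsoma

with tiledbsoma.Experiment.open("path/to/experiment") as exp:
    with exp.axis_query(
        measurement_name="RNA",
        obs_query=tiledbsoma.AxisQuery(value_filter="n_genes > 500"),
    ) as query:
        # Emits the SPATIAL_DISCLAIMER warning, builds the AnnData table,
        # then attaches each scene listed in obs_spatial_presence.
        sdata = query.to_spatialdata("data", scene_presence_mode="obs")
        print(sdata.tables["RNA"])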
4 changes: 2 additions & 2 deletions apis/python/src/tiledbsoma/io/spatial/__init__.py
@@ -5,6 +5,6 @@
"""

from .ingest import VisiumPaths, from_visium
from .outgest import to_spatial_data
from .outgest import to_spatialdata

__all__ = ["to_spatial_data", "from_visium", "VisiumPaths"]
__all__ = ["to_spatialdata", "from_visium", "VisiumPaths"]
48 changes: 24 additions & 24 deletions apis/python/src/tiledbsoma/io/spatial/outgest.py
@@ -69,7 +69,7 @@ def _convert_axis_names(
return spatial_data_axes, soma_dim_map


def _transform_to_spatial_data(
def _transform_to_spatialdata(
transform: somacore.CoordinateTransform,
input_dim_map: Dict[str, str],
output_dim_map: Dict[str, str],
@@ -102,7 +102,7 @@ def _transform_to_spatial_data(
)


def to_spatial_data_points(
def to_spatialdata_points(
points: PointCloudDataFrame,
*,
key: str,
@@ -135,7 +135,7 @@ def to_spatial_data_points(
transforms = {key: sd.transformations.Identity()}
else:
transforms = {
scene_id: _transform_to_spatial_data(
scene_id: _transform_to_spatialdata(
transform.inverse_transform(), points_dim_map, scene_dim_map
)
}
@@ -146,7 +146,7 @@ def to_spatial_data_points(
return sd.models.PointsModel.parse(df, transformations=transforms)


def to_spatial_data_shapes(
def to_spatialdata_shapes(
points: PointCloudDataFrame,
*,
key: str,
@@ -197,7 +197,7 @@ def to_spatial_data_shapes(
transforms = {key: sd.transformations.Identity()}
else:
transforms = {
scene_id: _transform_to_spatial_data(
scene_id: _transform_to_spatialdata(
transform.inverse_transform(), points_dim_map, scene_dim_map
)
}
@@ -220,7 +220,7 @@ def to_spatial_data_shapes(
return df


def to_spatial_data_image(
def to_spatialdata_image(
image: MultiscaleImage,
level: Optional[Union[str, int]] = None,
*,
@@ -272,7 +272,7 @@ def to_spatial_data_image(
if transform is None:
# Get the transformation from the image level to the highest resolution of the multiscale image.
scale_transform = image.get_transform_from_level(level)
sd_transform = _transform_to_spatial_data(
sd_transform = _transform_to_spatialdata(
scale_transform, image_dim_map, image_dim_map
)
transformations = {key: sd_transform}
@@ -287,14 +287,14 @@ def to_spatial_data_image(
scale_transform, somacore.IdentityTransform
):
# inv_transform @ scale_transform -> applies scale_transform first
sd_transform = _transform_to_spatial_data(
sd_transform = _transform_to_spatialdata(
inv_transform @ scale_transform, image_dim_map, scene_dim_map
)
else:
sd_transform1 = _transform_to_spatial_data(
sd_transform1 = _transform_to_spatialdata(
scale_transform, image_dim_map, image_dim_map
)
sd_transform2 = _transform_to_spatial_data(
sd_transform2 = _transform_to_spatialdata(
inv_transform, image_dim_map, scene_dim_map
)
# Sequence([sd_transform1, sd_transform2]) -> applies sd_transform1 first
@@ -310,7 +310,7 @@ def to_spatial_data_image(
)


def to_spatial_data_multiscale_image(
def to_spatialdata_multiscale_image(
image: MultiscaleImage,
*,
key: str,
@@ -350,7 +350,7 @@ def to_spatial_data_multiscale_image(

if transform is None:
spatial_data_transformations = tuple(
_transform_to_spatial_data(
_transform_to_spatialdata(
image.get_transform_from_level(level),
image_dim_map,
image_dim_map,
@@ -366,7 +366,7 @@ def to_spatial_data_multiscale_image(
if isinstance(transform, somacore.ScaleTransform):
# inv_transform @ scale_transform -> applies scale_transform first
spatial_data_transformations = tuple(
_transform_to_spatial_data(
_transform_to_spatialdata(
inv_transform @ image.get_transform_from_level(level),
image_dim_map,
scene_dim_map,
@@ -375,12 +375,12 @@ def to_spatial_data_multiscale_image(
)
else:
sd_scale_transforms = tuple(
_transform_to_spatial_data(
_transform_to_spatialdata(
image.get_transform_from_level(level), image_dim_map, image_dim_map
)
for level in range(1, image.level_count)
)
sd_inv_transform = _transform_to_spatial_data(
sd_inv_transform = _transform_to_spatialdata(
inv_transform, image_dim_map, scene_dim_map
)

@@ -420,7 +420,7 @@ def _get_transform_from_collection(
return None


def _add_scene_to_spatial_data(
def _add_scene_to_spatialdata(
sdata: sd.SpatialData,
scene_id: str,
scene: Scene,
@@ -456,7 +456,7 @@ def _add_scene_to_spatial_data(
transform = _get_transform_from_collection(key, scene.obsl.metadata)
if isinstance(df, PointCloudDataFrame):
if "soma_geometry" in df.metadata:
sdata.shapes[output_key] = to_spatial_data_shapes(
sdata.shapes[output_key] = to_spatialdata_shapes(
df,
key=output_key,
scene_id=scene_id,
@@ -465,7 +465,7 @@ def _add_scene_to_spatial_data(
soma_joinid_name=obs_id_name,
)
else:
sdata.points[output_key] = to_spatial_data_points(
sdata.points[output_key] = to_spatialdata_points(
df,
key=output_key,
scene_id=scene_id,
@@ -492,7 +492,7 @@ def _add_scene_to_spatial_data(
transform = _get_transform_from_collection(key, subcoll.metadata)
if isinstance(df, PointCloudDataFrame):
if "soma_geometry" in df.metadata:
sdata.shapes[output_key] = to_spatial_data_shapes(
sdata.shapes[output_key] = to_spatialdata_shapes(
df,
key=output_key,
scene_id=scene_id,
@@ -501,7 +501,7 @@ def _add_scene_to_spatial_data(
soma_joinid_name=var_id_name,
)
else:
sdata.points[output_key] = to_spatial_data_points(
sdata.points[output_key] = to_spatialdata_points(
df,
key=output_key,
scene_id=scene_id,
@@ -526,7 +526,7 @@ def _add_scene_to_spatial_data(
f"datatype {type(image).__name__}."
)
if image.level_count == 1:
sdata.images[output_key] = to_spatial_data_image(
sdata.images[output_key] = to_spatialdata_image(
image,
0,
key=output_key,
@@ -535,7 +535,7 @@ def _add_scene_to_spatial_data(
transform=transform,
)
else:
sdata.images[f"{scene_id}_{key}"] = to_spatial_data_multiscale_image(
sdata.images[f"{scene_id}_{key}"] = to_spatialdata_multiscale_image(
image,
key=output_key,
scene_id=scene_id,
@@ -544,7 +544,7 @@ def _add_scene_to_spatial_data(
)


def to_spatial_data(
def to_spatialdata(
experiment: Experiment,
*,
measurement_names: Optional[Sequence[str]] = None,
@@ -614,7 +614,7 @@ def to_spatial_data(

for scene_id in scene_names:
scene = experiment.spatial[scene_id]
_add_scene_to_spatial_data(
_add_scene_to_spatialdata(
sdata=sdata,
scene_id=scene_id,
scene=scene,
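To make the composition-order comments in the image exporters above concrete ("Sequence([sd_transform1, sd_transform2]) -> applies sd_transform1 first"), here is a small sketch using SpatialData's transformation classes; the scale factors and axis names are made up for illustration.

# Illustrates SpatialData's composition order: Sequence([t1, t2]) applies t1
# first, matching the inv_transform @ scale_transform comments above.
from spatialdata.transformations import Scale, Sequence

level_to_full = Scale([2.0, 2.0], axes=("x", "y"))   # image level -> full-resolution image
image_to_scene = Scale([0.5, 0.5], axes=("x", "y"))  # full resolution -> scene space (illustrative)

combined = Sequence([level_to_full, image_to_scene])  # level_to_full is applied first
print(combined.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y")))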
6 changes: 3 additions & 3 deletions apis/python/tests/test_basic_spatialdata_io.py
@@ -26,7 +26,7 @@ def sample_2d_data():

@pytest.fixture(scope="module")
def experiment_with_single_scene(tmp_path_factory, sample_2d_data) -> soma.Experiment:
uri = tmp_path_factory.mktemp("experiment_with_spatial_data").as_uri()
uri = tmp_path_factory.mktemp("experiment_with_spatialdata").as_uri()
with soma.Experiment.create(uri) as exp:
assert exp.uri == uri
# Create spatial folder.
@@ -173,7 +173,7 @@ def test_outgest_no_spatial(tmp_path, conftest_pbmc_small):

# Read full experiment into SpatialData.
with _factory.open(output_path) as exp:
sdata = spatial_outgest.to_spatial_data(exp)
sdata = spatial_outgest.to_spatialdata(exp)

# Check the number of assets (exactly 1 table) is as expected.
assert len(sdata.tables) == 2
@@ -204,7 +204,7 @@ def test_outgest_spatial_only(experiment_with_single_scene, sample_2d_data):

def test_outgest_spatial_only(experiment_with_single_scene, sample_2d_data):
# Export to SpatialData.
sdata = spatial_outgest.to_spatial_data(experiment_with_single_scene)
sdata = spatial_outgest.to_spatialdata(experiment_with_single_scene)

# Check the number of assets is correct.
assert len(sdata.tables) == 0