Skip to content

Commit

Permalink
Test spatialelement table join (#208)
Browse files Browse the repository at this point in the history
* fixed all the non-cytassist images

* refactor get_init_table_list

* fix permutation

* use join

* use join for view

* change hardcoded sample threshold

* bunch of updates

* adjust test

* change into left join

* change into left join

* random hex colors

* get actual index

* change indices and adata

* match rows left

* fix joins

* rename to vars

* fix color by

* fix radii

* pass cs

* fix _calc_default_radii()

* fix tests

* improve speed test

* removed comments

* added code to show circles as ellipses

* correct point layer size after scaling

* fix labels color by

* fix mibitof example

* last bug fixes

* initial step points columns

* add dask import back

* fix points columns

---------

Co-authored-by: wmv_hpomen <[email protected]>
Co-authored-by: Wouter-Michiel Vierdag <[email protected]>
  • Loading branch information
3 people authored Mar 19, 2024
1 parent cb1f6c4 commit 8f9ac6a
Show file tree
Hide file tree
Showing 15 changed files with 514 additions and 196 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@
# ("py:obj", "napari_spatialdata.QtAdataViewWidget.insertAction"),
# ("py:obj", "napari_spatialdata.QtAdataViewWidget.scroll"),
# ("py:obj", "napari_spatialdata.QtAdataViewWidget.setTabOrder"),
# ("py:class", "napari_spatialdata._model.ImageModel"),
# ("py:class", "napari_spatialdata._model.DataModel"),
# ("py:obj", "napari_spatialdata.QtAdataScatterWidget.scroll"),
# ("py:obj", "napari_spatialdata.QtAdataScatterWidget.insertAction"),
# If building the documentation fails because of a missing link that is outside your control,
Expand Down
2 changes: 2 additions & 0 deletions src/napari_spatialdata/_constants/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
POLYGON_THRESHOLD = 100
POINT_THRESHOLD = 100000
28 changes: 25 additions & 3 deletions src/napari_spatialdata/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@
from napari_spatialdata._constants._constants import Symbol
from napari_spatialdata.utils._utils import NDArrayA, _ensure_dense_vector

__all__ = ["ImageModel"]
__all__ = ["DataModel"]


@dataclass
class ImageModel:
class DataModel:
"""Model which holds the data for interactive visualization."""

events: EmitterGroup = field(init=False, default=None, repr=True)
_table_names: Sequence[Optional[str]] = field(default_factory=list, init=False)
_active_table_name: Optional[str] = field(default=None, init=False, repr=True)
_layer: Layer = field(init=False, default=None, repr=True)
_adata: Optional[AnnData] = field(init=False, default=None, repr=True)
_adata_layer: Optional[str] = field(init=False, default=None, repr=False)
Expand Down Expand Up @@ -58,6 +59,8 @@ def get_items(self, attr: str) -> Tuple[str, ...]:
"""
if attr in ("obs", "obsm"):
return tuple(map(str, getattr(self.adata, attr).keys()))
if attr == "points" and self.layer is not None and (point_cols := self.layer.metadata.get("points_columns")):
return tuple(map(str, point_cols.columns))
return tuple(map(str, getattr(self.adata, attr).index))

@_ensure_dense_vector
Expand All @@ -78,7 +81,18 @@ def get_obs(
"""
if name not in self.adata.obs.columns:
raise KeyError(f"Key `{name}` not found in `adata.obs`.")
return self.adata.obs[name], self._format_key(name)
if name != self.instance_key:
adata_obs = self.adata.obs[[self.instance_key, name]]
adata_obs.set_index(self.instance_key, inplace=True)
else:
adata_obs = self.adata.obs
return adata_obs[name], self._format_key(name)

@_ensure_dense_vector
def get_points(self, name: Union[str, int], **_: Any) -> Tuple[Optional[NDArrayA], str]:
if self.layer is None:
raise ValueError("Layer must be present")
return self.layer.metadata["points_columns"][name], self._format_key(name)

@_ensure_dense_vector
def get_var(self, name: Union[str, int], **_: Any) -> Tuple[Optional[NDArrayA], str]: # TODO(giovp): fix docstring
Expand Down Expand Up @@ -179,6 +193,14 @@ def layer(self, layer: Optional[Layer]) -> None:
self._layer = layer
self.events.layer()

@property
def active_table_name(self) -> Optional[str]:
return self._active_table_name

@active_table_name.setter
def active_table_name(self, active_table_name: Optional[str]) -> None:
self._active_table_name = active_table_name

@property
def adata(self) -> AnnData: # noqa: D102
return self._adata
Expand Down
12 changes: 6 additions & 6 deletions src/napari_spatialdata/_scatterwidgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from qtpy import QtWidgets
from qtpy.QtCore import Signal

from napari_spatialdata._model import ImageModel
from napari_spatialdata._model import DataModel
from napari_spatialdata._widgets import AListWidget, ComponentWidget
from napari_spatialdata.utils._categoricals_utils import _add_categorical_legend
from napari_spatialdata.utils._utils import NDArrayA, _get_categorical, _set_palette
Expand Down Expand Up @@ -57,7 +57,7 @@ class SelectFromCollection:

def __init__(
self,
model: ImageModel,
model: DataModel,
ax: Axes,
collection: Collection,
data: list[NDArrayA],
Expand Down Expand Up @@ -113,7 +113,7 @@ class ScatterListWidget(AListWidget):
_text = None
_chosen = None

def __init__(self, model: ImageModel, attr: str, color: bool, **kwargs: Any):
def __init__(self, model: DataModel, attr: str, color: bool, **kwargs: Any):
AListWidget.__init__(self, None, model, attr, **kwargs)
self.attrChanged.connect(self._onChange)
self._color = color
Expand Down Expand Up @@ -209,7 +209,7 @@ def data(self, data: NDArrayA | dict[str, Any]) -> None:


class MatplotlibWidget(NapariMPLWidget):
def __init__(self, viewer: Viewer | None, model: ImageModel):
def __init__(self, viewer: Viewer | None, model: DataModel):
self.is_widget = False
if viewer is None:
viewer = Viewer()
Expand Down Expand Up @@ -291,7 +291,7 @@ def clear(self) -> None:


class AxisWidgets(QtWidgets.QWidget):
def __init__(self, model: ImageModel, name: str, color: bool = False):
def __init__(self, model: DataModel, name: str, color: bool = False):
super().__init__()

self._model = model
Expand Down Expand Up @@ -346,6 +346,6 @@ def clear(self) -> None:
self.component_widget.clear()

@property
def model(self) -> ImageModel:
def model(self) -> DataModel:
""":mod:`napari` viewer."""
return self._model
81 changes: 41 additions & 40 deletions src/napari_spatialdata/_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from loguru import logger
from napari._qt.qt_resources import get_stylesheet
from napari._qt.utils import QImg2array
from napari.layers import Labels
from napari.layers import Labels, Points
from napari.viewer import Viewer
from qtpy.QtCore import QSize, Qt
from qtpy.QtWidgets import (
Expand All @@ -16,8 +16,9 @@
QVBoxLayout,
QWidget,
)
from spatialdata import join_sdata_spatialelement_table

from napari_spatialdata._model import ImageModel
from napari_spatialdata._model import DataModel
from napari_spatialdata._scatterwidgets import AxisWidgets, MatplotlibWidget
from napari_spatialdata._widgets import (
AListWidget,
Expand All @@ -28,14 +29,16 @@

__all__ = ["QtAdataViewWidget", "QtAdataScatterWidget"]

from napari_spatialdata.utils._utils import _get_init_table_list


class QtAdataScatterWidget(QWidget):
"""Adata viewer widget."""

def __init__(self, input: Viewer):
super().__init__()

self._model = ImageModel()
self._model = DataModel()

self.setLayout(QGridLayout())

Expand Down Expand Up @@ -109,14 +112,13 @@ def export(self) -> None:
def _update_adata(self) -> None:
if (table_name := self.table_name_widget.currentText()) == "":
return
self.model.active_table_name = table_name
layer = self._viewer.layers.selection.active
adata = None

if sdata := layer.metadata.get("sdata"):
element_name = layer.metadata.get("name")
table = sdata[table_name]
adata = table[table.obs[table.uns["spatialdata_attrs"]["region_key"]] == element_name]
layer.metadata["adata"] = adata
_, table = join_sdata_spatialelement_table(sdata, element_name, table_name, "left")
layer.metadata["adata"] = table

if layer is not None and "adata" in layer.metadata:
with self.model.events.adata.blocker():
Expand All @@ -126,10 +128,10 @@ def _update_adata(self) -> None:
return

self.model.instance_key = layer.metadata["instance_key"] = (
adata.uns["spatialdata_attrs"]["instance_key"] if adata is not None else None
table.uns["spatialdata_attrs"]["instance_key"] if table is not None else None
)
self.model.region_key = layer.metadata["region_key"] = (
adata.uns["spatialdata_attrs"]["region_key"] if adata is not None else None
table.uns["spatialdata_attrs"]["region_key"] if table is not None else None
)
self.model.system_name = layer.metadata["name"] if "name" in layer.metadata else None

Expand All @@ -148,7 +150,7 @@ def _on_selection(self, event: Any) -> None:
self.table_name_widget.clear()
self.table_name_widget.clear()
if event.source == self.model or event.source.active:
table_list = self._get_init_table_list()
table_list = _get_init_table_list(self.viewer.layers.selection.active)
if table_list:
self.model.table_names = table_list
self.table_name_widget.addItems(table_list)
Expand All @@ -161,14 +163,6 @@ def _on_selection(self, event: Any) -> None:
self.color_widget.widget._onChange()
self.color_widget.component_widget._onChange()

def _get_init_table_list(self) -> Optional[Sequence[Optional[str]]]:
layer = self.viewer.layers.selection.active

table_names: Optional[Sequence[Optional[str]]]
if table_names := layer.metadata.get("table_names"):
return table_names # type: ignore[no-any-return]
return None

def _select_layer(self) -> None:
"""Napari layers."""
layer = self._viewer.layers.selection.active
Expand All @@ -193,7 +187,7 @@ def viewer(self) -> napari.Viewer:
return self._viewer

@property
def model(self) -> ImageModel:
def model(self) -> DataModel:
""":mod:`napari` viewer."""
return self._model

Expand All @@ -205,7 +199,7 @@ def __init__(self, viewer: Viewer):
super().__init__()

self._viewer = viewer
self._model = ImageModel()
self._model = DataModel()

self._select_layer()
self._viewer.layers.selection.events.changed.connect(self._select_layer)
Expand All @@ -229,9 +223,9 @@ def __init__(self, viewer: Viewer):
self.layout().addWidget(obs_label)
self.layout().addWidget(self.obs_widget)

# gene
var_label = QLabel("Genes:")
var_label.setToolTip("Gene names from `adata.var_names` or `adata.raw.var_names`.")
# Vars
var_label = QLabel("Vars:")
var_label.setToolTip("Names from `adata.var_names` or `adata.raw.var_names`.")
self.var_widget = AListWidget(self.viewer, self.model, attr="var")
self.var_widget.setAdataLayer("X")

Expand All @@ -245,6 +239,7 @@ def __init__(self, viewer: Viewer):

self.adata_layer_widget.currentTextChanged.connect(self.var_widget.setAdataLayer)

self.layout().addWidget(adata_layer_label)
self.layout().addWidget(self.adata_layer_widget)
self.layout().addWidget(var_label)
self.layout().addWidget(self.var_widget)
Expand All @@ -262,6 +257,13 @@ def __init__(self, viewer: Viewer):
self.layout().addWidget(self.obsm_widget)
self.layout().addWidget(self.obsm_index_widget)

# Points columns
points_label = QLabel("Points columns:")
points_label.setToolTip("Columns in points element excluding dimension columns.")
self.points_widget = AListWidget(self.viewer, self.model, attr="points", multiselect=False)
self.layout().addWidget(points_label)
self.layout().addWidget(self.points_widget)

# color by
self.color_by = QLabel("Colored by:")
self.layout().addWidget(self.color_by)
Expand All @@ -282,7 +284,7 @@ def _on_layer_update(self, event: Optional[Any] = None) -> None:

self.table_name_widget.clear()

table_list = self._get_init_table_list()
table_list = _get_init_table_list(self.viewer.layers.selection.active)
if table_list:
self.model.table_names = table_list
self.table_name_widget.addItems(table_list)
Expand All @@ -291,6 +293,7 @@ def _on_layer_update(self, event: Optional[Any] = None) -> None:
self.adata_layer_widget.clear()
self.adata_layer_widget.addItem("X", None)
self.adata_layer_widget.addItems(self._get_adata_layer())
self.points_widget.clear()
self.obs_widget._onChange()
self.var_widget._onChange()
self.obsm_widget._onChange()
Expand All @@ -303,10 +306,16 @@ def _select_layer(self) -> None:
if hasattr(self, "obs_widget"):
self.table_name_widget.clear()
self.adata_layer_widget.clear()
self.points_widget.clear()
self.obs_widget.clear()
self.var_widget.clear()
self.obsm_widget.clear()
self.color_by.clear()
if (
isinstance(layer, Points)
and len(cols := layer.metadata["sdata"][layer.metadata["name"]].columns.drop(["x", "y"])) != 0
):
self.points_widget.addItems(map(str, cols))
return

if layer is not None and "adata" in layer.metadata:
Expand All @@ -329,14 +338,15 @@ def _select_layer(self) -> None:
def _update_adata(self) -> None:
if (table_name := self.table_name_widget.currentText()) == "":
return
self.model.active_table_name = table_name

layer = self._viewer.layers.selection.active
adata = None

if sdata := layer.metadata.get("sdata"):
element_name = layer.metadata.get("name")
table = sdata[table_name]
adata = table[table.obs[table.uns["spatialdata_attrs"]["region_key"]] == element_name]
layer.metadata["adata"] = adata
how = "left" if isinstance(layer, Labels) else "inner"
_, table = join_sdata_spatialelement_table(sdata, element_name, table_name, how)
layer.metadata["adata"] = table

if layer is not None and "adata" in layer.metadata:
with self.model.events.adata.blocker():
Expand All @@ -346,10 +356,10 @@ def _update_adata(self) -> None:
return

self.model.instance_key = layer.metadata["instance_key"] = (
adata.uns["spatialdata_attrs"]["instance_key"] if adata is not None else None
table.uns["spatialdata_attrs"]["instance_key"] if table is not None else None
)
self.model.region_key = layer.metadata["region_key"] = (
adata.uns["spatialdata_attrs"]["region_key"] if adata is not None else None
table.uns["spatialdata_attrs"]["region_key"] if table is not None else None
)
self.model.system_name = layer.metadata["name"] if "name" in layer.metadata else None

Expand All @@ -371,15 +381,6 @@ def _get_adata_layer(self) -> Sequence[Optional[str]]:
return adata_layers
return [None]

def _get_init_table_list(self) -> Optional[Sequence[Optional[str]]]:
layer = self.viewer.layers.selection.active

table_names: Optional[Sequence[Optional[str]]]
if table_names := layer.metadata.get("table_names"):
return table_names # type: ignore[no-any-return]

return None

def _change_color_by(self) -> None:
self.color_by.setText(f"Color by: {self.model.color_by}")

Expand All @@ -389,7 +390,7 @@ def viewer(self) -> napari.Viewer:
return self._viewer

@property
def model(self) -> ImageModel:
def model(self) -> DataModel:
""":mod:`napari` viewer."""
return self._model

Expand Down
Loading

0 comments on commit 8f9ac6a

Please sign in to comment.