Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test spatialelement table join #208

Merged
merged 35 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
b648b4c
fixed all the non-cytassist images
LucaMarconato Mar 13, 2024
eb2639a
Merge branch 'multi_table' into test_for_join
LucaMarconato Mar 14, 2024
97eb5e8
refactor get_init_table_list
melonora Mar 14, 2024
017c5dc
fix permutation
melonora Mar 14, 2024
1d10bae
use join
melonora Mar 15, 2024
687ed18
use join for view
melonora Mar 15, 2024
53a1564
change hardcoded sample threshold
melonora Mar 15, 2024
6b21e6c
bunch of updates
melonora Mar 15, 2024
b0b7742
adjust test
melonora Mar 15, 2024
41540b6
change into left join
melonora Mar 16, 2024
51ab7d9
change into left join
melonora Mar 16, 2024
12e2ad0
random hex colors
melonora Mar 16, 2024
c51ded2
get actual index
melonora Mar 16, 2024
a648f06
change indices and adata
melonora Mar 17, 2024
0210c9a
match rows left
melonora Mar 17, 2024
7d08123
fix joins
melonora Mar 17, 2024
523e543
rename to vars
melonora Mar 18, 2024
fe9d5bf
fix color by
melonora Mar 18, 2024
33bf4ec
fix radii
melonora Mar 18, 2024
b898321
pass cs
melonora Mar 18, 2024
1fcc1c6
fix _calc_default_radii()
LucaMarconato Mar 18, 2024
061d48e
fix tests
melonora Mar 18, 2024
accde54
Merge branch 'test_for_join' of https://github.com/scverse/napari-spa…
melonora Mar 18, 2024
1205089
improve speed test
melonora Mar 18, 2024
c8e0416
removed comments
LucaMarconato Mar 18, 2024
54a0be9
added code to show circles as ellipses
LucaMarconato Mar 18, 2024
e93866b
Merge branch 'test_for_join' of https://github.com/scverse/napari-spa…
LucaMarconato Mar 18, 2024
97fd8e7
correct point layer size after scaling
LucaMarconato Mar 18, 2024
a39e535
fix labels color by
melonora Mar 18, 2024
b73a254
Merge branch 'test_for_join' of https://github.com/scverse/napari-spa…
melonora Mar 18, 2024
12a46b6
fix mibitof example
melonora Mar 18, 2024
4777285
last bug fixes
melonora Mar 19, 2024
473a26f
initial step points columns
melonora Mar 19, 2024
f90f9df
add dask import back
melonora Mar 19, 2024
3d62854
fix points columns
melonora Mar 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@
# ("py:obj", "napari_spatialdata.QtAdataViewWidget.insertAction"),
# ("py:obj", "napari_spatialdata.QtAdataViewWidget.scroll"),
# ("py:obj", "napari_spatialdata.QtAdataViewWidget.setTabOrder"),
# ("py:class", "napari_spatialdata._model.ImageModel"),
# ("py:class", "napari_spatialdata._model.DataModel"),
# ("py:obj", "napari_spatialdata.QtAdataScatterWidget.scroll"),
# ("py:obj", "napari_spatialdata.QtAdataScatterWidget.insertAction"),
# If building the documentation fails because of a missing link that is outside your control,
Expand Down
20 changes: 17 additions & 3 deletions src/napari_spatialdata/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@
from napari_spatialdata._constants._constants import Symbol
from napari_spatialdata.utils._utils import NDArrayA, _ensure_dense_vector

__all__ = ["ImageModel"]
__all__ = ["DataModel"]


@dataclass
class ImageModel:
class DataModel:
"""Model which holds the data for interactive visualization."""

events: EmitterGroup = field(init=False, default=None, repr=True)
_table_names: Sequence[Optional[str]] = field(default_factory=list, init=False)
_active_table_name: Optional[str] = field(default=None, init=False, repr=True)
_layer: Layer = field(init=False, default=None, repr=True)
_adata: Optional[AnnData] = field(init=False, default=None, repr=True)
_adata_layer: Optional[str] = field(init=False, default=None, repr=False)
Expand Down Expand Up @@ -78,7 +79,12 @@ def get_obs(
"""
if name not in self.adata.obs.columns:
raise KeyError(f"Key `{name}` not found in `adata.obs`.")
return self.adata.obs[name], self._format_key(name)
if name != self.instance_key:
adata_obs = self.adata.obs[[self.instance_key, name]]
adata_obs.set_index(self.instance_key, inplace=True)
else:
adata_obs = self.adata.obs
return adata_obs[name], self._format_key(name)
LucaMarconato marked this conversation as resolved.
Show resolved Hide resolved

@_ensure_dense_vector
def get_var(self, name: Union[str, int], **_: Any) -> Tuple[Optional[NDArrayA], str]: # TODO(giovp): fix docstring
Expand Down Expand Up @@ -179,6 +185,14 @@ def layer(self, layer: Optional[Layer]) -> None:
self._layer = layer
self.events.layer()

@property
def active_table_name(self) -> Optional[str]:
return self._active_table_name

@active_table_name.setter
def active_table_name(self, active_table_name: Optional[str]) -> None:
self._active_table_name = active_table_name

@property
def adata(self) -> AnnData: # noqa: D102
return self._adata
Expand Down
12 changes: 6 additions & 6 deletions src/napari_spatialdata/_scatterwidgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from qtpy import QtWidgets
from qtpy.QtCore import Signal

from napari_spatialdata._model import ImageModel
from napari_spatialdata._model import DataModel
from napari_spatialdata._widgets import AListWidget, ComponentWidget
from napari_spatialdata.utils._categoricals_utils import _add_categorical_legend
from napari_spatialdata.utils._utils import NDArrayA, _get_categorical, _set_palette
Expand Down Expand Up @@ -57,7 +57,7 @@ class SelectFromCollection:

def __init__(
self,
model: ImageModel,
model: DataModel,
ax: Axes,
collection: Collection,
data: list[NDArrayA],
Expand Down Expand Up @@ -113,7 +113,7 @@ class ScatterListWidget(AListWidget):
_text = None
_chosen = None

def __init__(self, model: ImageModel, attr: str, color: bool, **kwargs: Any):
def __init__(self, model: DataModel, attr: str, color: bool, **kwargs: Any):
AListWidget.__init__(self, None, model, attr, **kwargs)
self.attrChanged.connect(self._onChange)
self._color = color
Expand Down Expand Up @@ -209,7 +209,7 @@ def data(self, data: NDArrayA | dict[str, Any]) -> None:


class MatplotlibWidget(NapariMPLWidget):
def __init__(self, viewer: Viewer | None, model: ImageModel):
def __init__(self, viewer: Viewer | None, model: DataModel):
self.is_widget = False
if viewer is None:
viewer = Viewer()
Expand Down Expand Up @@ -291,7 +291,7 @@ def clear(self) -> None:


class AxisWidgets(QtWidgets.QWidget):
def __init__(self, model: ImageModel, name: str, color: bool = False):
def __init__(self, model: DataModel, name: str, color: bool = False):
super().__init__()

self._model = model
Expand Down Expand Up @@ -346,6 +346,6 @@ def clear(self) -> None:
self.component_widget.clear()

@property
def model(self) -> ImageModel:
def model(self) -> DataModel:
""":mod:`napari` viewer."""
return self._model
63 changes: 25 additions & 38 deletions src/napari_spatialdata/_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
QVBoxLayout,
QWidget,
)
from spatialdata import join_sdata_spatialelement_table

from napari_spatialdata._model import ImageModel
from napari_spatialdata._model import DataModel
from napari_spatialdata._scatterwidgets import AxisWidgets, MatplotlibWidget
from napari_spatialdata._widgets import (
AListWidget,
Expand All @@ -28,14 +29,16 @@

__all__ = ["QtAdataViewWidget", "QtAdataScatterWidget"]

from napari_spatialdata.utils._utils import _get_init_table_list


class QtAdataScatterWidget(QWidget):
"""Adata viewer widget."""

def __init__(self, input: Viewer):
super().__init__()

self._model = ImageModel()
self._model = DataModel()

self.setLayout(QGridLayout())

Expand Down Expand Up @@ -109,14 +112,13 @@ def export(self) -> None:
def _update_adata(self) -> None:
if (table_name := self.table_name_widget.currentText()) == "":
return
self.model.active_table_name = table_name
layer = self._viewer.layers.selection.active
adata = None

if sdata := layer.metadata.get("sdata"):
element_name = layer.metadata.get("name")
table = sdata[table_name]
adata = table[table.obs[table.uns["spatialdata_attrs"]["region_key"]] == element_name]
layer.metadata["adata"] = adata
_, table = join_sdata_spatialelement_table(sdata, element_name, table_name, "left")
layer.metadata["adata"] = table

if layer is not None and "adata" in layer.metadata:
with self.model.events.adata.blocker():
Expand All @@ -126,10 +128,10 @@ def _update_adata(self) -> None:
return

self.model.instance_key = layer.metadata["instance_key"] = (
adata.uns["spatialdata_attrs"]["instance_key"] if adata is not None else None
table.uns["spatialdata_attrs"]["instance_key"] if table is not None else None
)
self.model.region_key = layer.metadata["region_key"] = (
adata.uns["spatialdata_attrs"]["region_key"] if adata is not None else None
table.uns["spatialdata_attrs"]["region_key"] if table is not None else None
)
self.model.system_name = layer.metadata["name"] if "name" in layer.metadata else None

Expand All @@ -148,7 +150,7 @@ def _on_selection(self, event: Any) -> None:
self.table_name_widget.clear()
self.table_name_widget.clear()
if event.source == self.model or event.source.active:
table_list = self._get_init_table_list()
table_list = _get_init_table_list(self.viewer.layers.selection.active)
if table_list:
self.model.table_names = table_list
self.table_name_widget.addItems(table_list)
Expand All @@ -161,14 +163,6 @@ def _on_selection(self, event: Any) -> None:
self.color_widget.widget._onChange()
self.color_widget.component_widget._onChange()

def _get_init_table_list(self) -> Optional[Sequence[Optional[str]]]:
layer = self.viewer.layers.selection.active

table_names: Optional[Sequence[Optional[str]]]
if table_names := layer.metadata.get("table_names"):
return table_names # type: ignore[no-any-return]
return None

def _select_layer(self) -> None:
"""Napari layers."""
layer = self._viewer.layers.selection.active
Expand All @@ -193,7 +187,7 @@ def viewer(self) -> napari.Viewer:
return self._viewer

@property
def model(self) -> ImageModel:
def model(self) -> DataModel:
""":mod:`napari` viewer."""
return self._model

Expand All @@ -205,7 +199,7 @@ def __init__(self, viewer: Viewer):
super().__init__()

self._viewer = viewer
self._model = ImageModel()
self._model = DataModel()

self._select_layer()
self._viewer.layers.selection.events.changed.connect(self._select_layer)
Expand All @@ -230,8 +224,8 @@ def __init__(self, viewer: Viewer):
self.layout().addWidget(self.obs_widget)

# gene
var_label = QLabel("Genes:")
var_label.setToolTip("Gene names from `adata.var_names` or `adata.raw.var_names`.")
var_label = QLabel("Vars:")
var_label.setToolTip("Names from `adata.var_names` or `adata.raw.var_names`.")
self.var_widget = AListWidget(self.viewer, self.model, attr="var")
self.var_widget.setAdataLayer("X")

Expand All @@ -245,6 +239,7 @@ def __init__(self, viewer: Viewer):

self.adata_layer_widget.currentTextChanged.connect(self.var_widget.setAdataLayer)

self.layout().addWidget(adata_layer_label)
self.layout().addWidget(self.adata_layer_widget)
self.layout().addWidget(var_label)
self.layout().addWidget(self.var_widget)
Expand Down Expand Up @@ -282,7 +277,7 @@ def _on_layer_update(self, event: Optional[Any] = None) -> None:

self.table_name_widget.clear()

table_list = self._get_init_table_list()
table_list = _get_init_table_list(self.viewer.layers.selection.active)
if table_list:
self.model.table_names = table_list
self.table_name_widget.addItems(table_list)
Expand Down Expand Up @@ -329,14 +324,15 @@ def _select_layer(self) -> None:
def _update_adata(self) -> None:
if (table_name := self.table_name_widget.currentText()) == "":
return
self.model.active_table_name = table_name

layer = self._viewer.layers.selection.active
adata = None

if sdata := layer.metadata.get("sdata"):
element_name = layer.metadata.get("name")
table = sdata[table_name]
adata = table[table.obs[table.uns["spatialdata_attrs"]["region_key"]] == element_name]
layer.metadata["adata"] = adata
how = "left" if isinstance(layer, Labels) else "inner"
_, table = join_sdata_spatialelement_table(sdata, element_name, table_name, how)
layer.metadata["adata"] = table

if layer is not None and "adata" in layer.metadata:
with self.model.events.adata.blocker():
Expand All @@ -346,10 +342,10 @@ def _update_adata(self) -> None:
return

self.model.instance_key = layer.metadata["instance_key"] = (
adata.uns["spatialdata_attrs"]["instance_key"] if adata is not None else None
table.uns["spatialdata_attrs"]["instance_key"] if table is not None else None
)
self.model.region_key = layer.metadata["region_key"] = (
adata.uns["spatialdata_attrs"]["region_key"] if adata is not None else None
table.uns["spatialdata_attrs"]["region_key"] if table is not None else None
)
self.model.system_name = layer.metadata["name"] if "name" in layer.metadata else None

Expand All @@ -371,15 +367,6 @@ def _get_adata_layer(self) -> Sequence[Optional[str]]:
return adata_layers
return [None]

def _get_init_table_list(self) -> Optional[Sequence[Optional[str]]]:
layer = self.viewer.layers.selection.active

table_names: Optional[Sequence[Optional[str]]]
if table_names := layer.metadata.get("table_names"):
return table_names # type: ignore[no-any-return]

return None

def _change_color_by(self) -> None:
self.color_by.setText(f"Color by: {self.model.color_by}")

Expand All @@ -389,7 +376,7 @@ def viewer(self) -> napari.Viewer:
return self._viewer

@property
def model(self) -> ImageModel:
def model(self) -> DataModel:
""":mod:`napari` viewer."""
return self._model

Expand Down
32 changes: 19 additions & 13 deletions src/napari_spatialdata/_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
from napari.utils.notifications import show_info
from qtpy.QtCore import QObject, Signal
from shapely import Polygon
from spatialdata._core.query.relational_query import _get_element_annotators
from spatialdata._core.query.relational_query import _get_element_annotators, _left_join_spatialelement_table
from spatialdata.models import PointsModel, ShapesModel
from spatialdata.transformations import Affine, Identity
from spatialdata.transformations._utils import scale_radii

from napari_spatialdata.utils._utils import (
_adjust_channels_order,
_calc_default_radii,
_get_init_metadata_adata,
_get_transform,
_transform_coordinates,
Expand Down Expand Up @@ -276,7 +277,9 @@ def inherit_metadata(self, layers: list[Layer], ref_layer: Layer, show_tooltip:

show_info(f"Layer(s) inherited info from {ref_layer}")

def _get_table_data(self, sdata: SpatialData, element_name: str) -> tuple[AnnData, str | None, list[str | None]]:
def _get_table_data(
self, sdata: SpatialData, element_name: str
) -> tuple[AnnData | None, str | None, list[str | None]]:
table_names = list(_get_element_annotators(sdata, element_name))
table_name = table_names[0] if len(table_names) > 0 else None
adata = _get_init_metadata_adata(sdata, table_name, element_name)
Expand Down Expand Up @@ -417,35 +420,39 @@ def add_sdata_points(self, sdata: SpatialData, key: str, selected_cs: str, multi

points = sdata.points[original_name].compute()
affine = _get_transform(sdata.points[original_name], selected_cs)
adata, table_name, table_names = self._get_table_data(sdata, original_name)
if len(points) < POINT_THRESHOLD:
subsample = np.arange(len(points))
subsample = None
else:
logger.info("Subsampling points because the number of points exceeds the currently supported 100 000.")
gen = np.random.default_rng()
subsample = np.sort(gen.choice(len(points), size=100000, replace=False)) # same as indices

# TODO consider subsampling adata and passing that on.
adata, table_name, table_names = self._get_table_data(sdata, original_name)
subsample = np.sort(gen.choice(len(points), size=POINT_THRESHOLD, replace=False)) # same as indices

xy = points[["y", "x"]].values[subsample]
subsample_points = points.iloc[subsample] if subsample else points
if subsample:
adata = _left_join_spatialelement_table(
{"points": {original_name: subsample_points}}, table_name, match_rows="left"
LucaMarconato marked this conversation as resolved.
Show resolved Hide resolved
)
xy = subsample_points[["y", "x"]].values
np.fliplr(xy)
radii_size = _calc_default_radii(self.viewer, sdata, selected_cs)
layer = self.viewer.add_points(
xy,
name=key,
size=20,
size=radii_size,
affine=affine,
edge_width=0.0,
metadata={
"sdata": sdata,
"adata": AnnData(obs=points.iloc[subsample, :]),
"adata": adata,
"name": original_name,
"region_key": sdata[table_name].uns["spatialdata_attrs"]["region_key"] if table_name else None,
"instance_key": sdata[table_name].uns["spatialdata_attrs"]["instance_key"] if table_name else None,
"table_names": table_names if table_name else None,
"_active_in_cs": {selected_cs},
"_current_cs": selected_cs,
"_n_indices": len(points),
"indices": subsample.tolist(),
"indices": subsample_points.index.to_list(),
},
)
assert affine is not None
Expand All @@ -461,8 +468,7 @@ def _adjust_radii_of_points_layer(self, layer: Layer, affine: npt.ArrayLike) ->
else:
raise ValueError(f"Invalid affine shape: {affine.shape}")
affine_transformation = Affine(affine, input_axes=axes, output_axes=axes)
metadata = layer.metadata
radii = metadata["sdata"][metadata["name"]].radius.to_numpy()
radii = layer.size # TODO fix scale to radii
new_radii = scale_radii(radii=radii, affine=affine_transformation, axes=axes)
layer.size = new_radii * 2

Expand Down
Loading