Skip to content

Commit

Permalink
v0.1
Browse files Browse the repository at this point in the history
  • Loading branch information
ucagenomix committed Feb 7, 2024
1 parent 253ebaa commit 9a327e8
Show file tree
Hide file tree
Showing 6 changed files with 305 additions and 79 deletions.
1 change: 1 addition & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
.. autosummary::
:toctree: generated
pl.plot_shapes
pl.plot_shape_along_axis
pl.plot_qc
pl.plot_per_groups
Expand Down
4 changes: 3 additions & 1 deletion src/scispy/pl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from .basic import get_palette, plot_per_groups, plot_qc, plot_shape_along_axis
from .basic import get_palette, plot_per_groups, plot_qc, plot_sdata, plot_shape_along_axis, plot_shapes

__all__ = [
"plot_shapes",
"plot_shape_along_axis",
"get_palette",
"plot_qc",
"plot_per_groups",
"plot_sdata",
]
160 changes: 119 additions & 41 deletions src/scispy/pl/basic.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,77 @@
import math

import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import spatialdata as sd
import squidpy as sq
from matplotlib import pyplot as plt
from spatialdata.transformations import Affine, set_transformation

from scispy.tl.basic import sdata_rotate


def plot_shapes(
sdata: sd.SpatialData,
group_lst: tuple = [], # the cell types to consider
shapes_lst: tuple = [], # the shapes to plot
label_obs_key: str = "celltype_spatial",
shape_key: str = "arteries",
target_coordinates: str = "microns",
figsize: tuple = (12, 6),
save: bool = False,
):
"""Plot list of shapes
Parameters
----------
sdata
SpatialData object obtained by tl.get_sdata_polygon()
group_lst
group list to consider (related to label_obs_key)
shapes_lst
shapes list to plot
label_obs_key
label_key in sdata.table.obs to consider
shape_key
SpatialData shape element to consider
target_coordinates
target_coordinates system of sdata object
figsize
figure size
save
wether or not to save the figure
"""
region_key = sdata.table.uns["spatialdata_attrs"]["region"]

if group_lst is None:
group_lst = sdata.table.obs[label_obs_key].unique().tolist()

fig, axs = plt.subplots(ncols=len(shapes_lst), nrows=1, figsize=figsize)
for i in range(0, len(shapes_lst)):
poly = sdata[shape_key][sdata[shape_key].name == shapes_lst[i]].geometry.item()
sdata2 = sd.polygon_query(
sdata,
poly,
target_coordinate_system=target_coordinates,
filter_table=True,
points=False,
shapes=True,
images=True,
)

sdata2.pl.render_images().pl.show(ax=axs[i])
sdata2.pl.render_shapes(elements=shape_key, outline=True, fill_alpha=0.25, outline_color="red").pl.show(
ax=axs[i]
)
# sdata2.pl.render_shapes(elements=region_key, color=label_obs_key, groups=group_lst).pl.show(ax=axs[i])
# sdata2.pp.get_elements([region_key]).pl.render_shapes(color=label_obs_key, groups=group_lst).pl.show(ax=axs[i])
sdata2.pl.render_shapes(elements=region_key).pl.show(ax=axs[i])

axs[i].set_title(shapes_lst[i])
# if(i < len(shapes_lst)):
# axs[i].get_legend().remove()

plt.tight_layout()


def plot_shape_along_axis(
Expand Down Expand Up @@ -45,7 +110,6 @@ def plot_shape_along_axis(
horary rotation angle of the shape before computing along x axis
save
wether or not to save the figure
"""
if group_lst is None:
group_lst = sdata.table.obs[label_obs_key].unique().tolist()
Expand All @@ -62,37 +126,16 @@ def plot_shape_along_axis(
# images=True,
# )

# rotate the shape along x axis
if rotation_angle != 0:
theta = math.pi / (360 / rotation_angle)
# perform rotation of shape
rotation = Affine(
[
[math.cos(theta), -math.sin(theta), 0],
[math.sin(theta), math.cos(theta), 0],
[0, 0, 1],
],
input_axes=("x", "y"),
output_axes=("x", "y"),
)
# translation = Translation([0, 0], axes=("x", "y"))
# sequence = Sequence([rotation, translation])

elements = list(sdata.points.keys()) + list(sdata.shapes.keys())
for i in range(0, len(elements)):
set_transformation(sdata[elements[i]], rotation, to_coordinate_system=target_coordinates)

# convert to final coordinates
sdata3 = sdata.transform_to_coordinate_system(target_coordinates)
sdata2 = sdata_rotate(sdata, rotation_angle, target_coordinates)

# get elements key, probably need to do better here in the futur !!
dataset_id = sdata3.table.obs.dataset_id.unique().tolist()[0]
dataset_id = sdata2.table.obs.dataset_id.unique().tolist()[0]
sdata_transcript_key = dataset_id + "_transcripts"
sdata_polygon_key = dataset_id + "_polygons"

# compute dataframes
df_transcripts = sdata3[sdata_transcript_key].compute()
df_celltypes = sdata3[sdata_group_key].compute()
df_transcripts = sdata2[sdata_transcript_key].compute()
df_celltypes = sdata2[sdata_group_key].compute()

# parametrage
x_min = df_transcripts.x.min()
Expand Down Expand Up @@ -120,7 +163,7 @@ def plot_shape_along_axis(

valct = pd.DataFrame({"microns": [], "count": [], "cell_type": []})
for ct in range(0, len(group_lst)):
df2 = df_celltypes[df_celltypes[label_obs_key] == group_lst[ct]]
df2 = df_celltypes[df_celltypes["ct"] == group_lst[ct]]
for i in range(0, step_number):
new_row = {
"microns": (x_min + (i + 0.5) * bin_size),
Expand All @@ -136,7 +179,7 @@ def plot_shape_along_axis(

# draw figure
fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, sharex=True)
sdata3.pl.render_shapes(
sdata.pl.render_shapes(
elements=sdata_polygon_key, color=label_obs_key, palette=list(mypal.values()), groups=group_lst
).pl.show(ax=ax1)

Expand Down Expand Up @@ -171,6 +214,26 @@ def plot_shape_along_axis(
plt.savefig(dataset_id + ".pdf", format="pdf", bbox_inches="tight")


def plot_sdata(
sdata: sd.SpatialData,
color_key: str = "celltype",
):
"""Plot sdata object (i.e. embedding and polygons). This should always works if well synchronized sdata object
Parameters
----------
sdata
SpatialData object.
color_key
color key from .table.obs
"""
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
sc.pl.embedding(sdata.table, "umap", color=color_key, ax=ax1, show=False)
ax1.get_legend().remove()
sdata.pl.render_shapes(elements=sdata.table.uns["spatialdata_attrs"]["region"], color=color_key).pl.show(ax=ax2)
plt.tight_layout()


def get_palette(color_key: str) -> dict:
"""Palette definition for specific projects.
Expand Down Expand Up @@ -297,7 +360,7 @@ def plot_qc(sdata: sd.SpatialData):
plt.tight_layout()


def plot_per_groups(adata, clust_key, size=60, frameon=False, legend_loc=None, **kwargs):
def plot_per_groups(adata, clust_key, size=60, is_spatial=False, frameon=False, legend_loc=None, **kwargs):
"""Plot UMAP splitted by clust_key
Parameters
Expand All @@ -306,6 +369,8 @@ def plot_per_groups(adata, clust_key, size=60, frameon=False, legend_loc=None, *
Anndata object.
clust_key
key to plot
is_spatial
UMAP plot if False,
"""
tmp = adata.copy()
Expand All @@ -314,12 +379,25 @@ def plot_per_groups(adata, clust_key, size=60, frameon=False, legend_loc=None, *
tmp.obs[clust] = adata.obs[clust_key].isin([clust]).astype("category")
tmp.uns[clust + "_colors"] = ["#d3d3d3", adata.uns[clust_key + "_colors"][i]]

sc.pl.umap(
tmp,
groups=tmp.obs[clust].cat.categories[1:].values,
color=adata.obs[clust_key].cat.categories.tolist(),
size=size,
frameon=frameon,
legend_loc=legend_loc,
**kwargs,
)
if is_spatial is False:
sc.pl.umap(
tmp,
groups=tmp.obs[clust].cat.categories[1:].values,
color=adata.obs[clust_key].cat.categories.tolist(),
size=size,
frameon=frameon,
legend_loc=legend_loc,
**kwargs,
)
else:
# not working !....
tmp.uns["spatial"] = tmp.obsm["spatial"]
sq.pl.spatial_scatter(
tmp,
groups=tmp.obs[clust].cat.categories[1:].values,
color=adata.obs[clust_key].cat.categories.tolist(),
size=size,
frameon=frameon,
legend_loc=legend_loc,
**kwargs,
)
9 changes: 9 additions & 0 deletions src/scispy/pp/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def scvi_annotate(
label_key: str = "celltype",
layer: str = "counts",
metaref2add: str = None,
filter_under_score: float = 0.5,
):
"""Annotate anndata spatial cells using anndata cells reference using SCVI.
Expand All @@ -82,6 +83,8 @@ def scvi_annotate(
layer in which we can find the raw count values.
metaref2add
.obs key in single-cell reference object to transfert to spatial.
filter_under_score
remove cells having a scvi assignment score under this cutoff
"""
ad_spatial.var.index = ad_spatial.var.index.str.upper()
Expand Down Expand Up @@ -147,3 +150,9 @@ def scvi_annotate(
d = pd.Series(ad_ref.obs[f"{metaref2add}"].values, index=ad_ref.obs[f"{label_ref}"]).to_dict()
ad_spatial.obs[f"{metaref2add}"] = ad_spatial.obs[f"{label_key}"].map(d)
ad_spatial.obs[f"{metaref2add}"] = ad_spatial.obs[f"{metaref2add}"].astype("category")

# remove cells having a bad score
nb_cells = ad_spatial.shape[0]
ad_spatial = ad_spatial[ad_spatial.obs[f"{label_key}_score"] >= filter_under_score]
filtered_cells = nb_cells - ad_spatial.shape[0]
print("low assignment score filtering ", filtered_cells)
4 changes: 4 additions & 0 deletions src/scispy/tl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
get_sdata_polygon,
prep_pseudobulk,
run_pseudobulk,
sdata_querybox,
sdata_rotate,
)

__all__ = [
Expand All @@ -12,4 +14,6 @@
"get_sdata_polygon",
"prep_pseudobulk",
"run_pseudobulk",
"sdata_rotate",
"sdata_querybox",
]
Loading

0 comments on commit 9a327e8

Please sign in to comment.