Skip to content

Commit

Permalink
v0.1
Browse files Browse the repository at this point in the history
  • Loading branch information
ucagenomix committed Mar 1, 2024
1 parent 15df0ab commit 0de348b
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 71 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies = [
"spatialdata",
"spatialdata-io",
"spatialdata-plot",
"napari-spatialdata",
#"napari-spatialdata",
"squidpy",
#"torch==1.13.1",

Expand Down
13 changes: 12 additions & 1 deletion src/scispy/pl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
from .basic import get_palette, plot_per_groups, plot_qc, plot_sdata, plot_shape_along_axis, plot_shapes
from .basic import (
get_palette,
legend_without_duplicate_labels,
plot_multi_sdata,
plot_per_groups,
plot_qc,
plot_sdata,
plot_shape_along_axis,
plot_shapes,
)

__all__ = [
"plot_shapes",
Expand All @@ -7,4 +16,6 @@
"plot_qc",
"plot_per_groups",
"plot_sdata",
"plot_multi_sdata",
"legend_without_duplicate_labels",
]
169 changes: 124 additions & 45 deletions src/scispy/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def plot_shapes(
shapes_lst: tuple = [], # the shapes to plot
label_obs_key: str = "celltype_spatial",
shape_key: str = "arteries",
color_dict: tuple = [],
target_coordinates: str = "microns",
figsize: tuple = (12, 6),
save: bool = False,
Expand All @@ -33,6 +34,8 @@ def plot_shapes(
label_key in sdata.table.obs to consider
shape_key
SpatialData shape element to consider
color_dict
dictionary of colors to use
target_coordinates
target_coordinates system of sdata object
figsize
Expand All @@ -43,9 +46,6 @@ def plot_shapes(
"""
region_key = sdata.table.uns["spatialdata_attrs"]["region"]

if group_lst is None:
group_lst = sdata.table.obs[label_obs_key].unique().tolist()

fig, axs = plt.subplots(ncols=len(shapes_lst), nrows=1, figsize=figsize)
for i in range(0, len(shapes_lst)):
poly = sdata[shape_key][sdata[shape_key].name == shapes_lst[i]].geometry.item()
Expand All @@ -60,12 +60,16 @@ def plot_shapes(
)

sdata2.pl.render_images().pl.show(ax=axs[i])
sdata2.pl.render_shapes(elements=shape_key, outline=True, fill_alpha=0.25, outline_color="red").pl.show(
ax=axs[i]
)
# sdata2.pl.render_shapes(elements=region_key, color=label_obs_key, groups=group_lst).pl.show(ax=axs[i])
# sdata2.pp.get_elements([region_key]).pl.render_shapes(color=label_obs_key, groups=group_lst).pl.show(ax=axs[i])
sdata2.pl.render_shapes(elements=region_key).pl.show(ax=axs[i])
if group_lst is None:
group_lst = sdata.table.obs[label_obs_key].unique().tolist()

if color_dict is not None:
mypal = [color_dict[x] for x in group_lst]
sdata2.pl.render_shapes(elements=region_key, color=label_obs_key, groups=group_lst, palette=mypal).pl.show(
ax=axs[i]
)
else:
sdata2.pl.render_shapes(elements=region_key, color=label_obs_key, groups=group_lst).pl.show(ax=axs[i])

axs[i].set_title(shapes_lst[i])
# if(i < len(shapes_lst)):
Expand All @@ -78,9 +82,10 @@ def plot_shape_along_axis(
sdata: sd.SpatialData,
group_lst: tuple = [], # the cell types to consider
gene_lst: tuple = [], # the genes to consider
label_obs_key: str = "celltype_spatial",
label_obs_key: str = "celltype",
sdata_group_key: str = "celltypes",
target_coordinates: str = "microns",
color_dict: tuple = [],
scale_expr: bool = False,
bin_size: int = 50,
rotation_angle: int = 0,
Expand Down Expand Up @@ -126,7 +131,8 @@ def plot_shape_along_axis(
# images=True,
# )

sdata2 = sdata_rotate(sdata, rotation_angle, target_coordinates)
sdata_rotate(sdata, rotation_angle)
sdata2 = sdata.transform_to_coordinate_system(target_coordinates)

# get elements key, probably need to do better here in the futur !!
dataset_id = sdata2.table.obs.dataset_id.unique().tolist()[0]
Expand All @@ -146,6 +152,9 @@ def plot_shape_along_axis(
cats = sdata.table.obs[label_obs_key].cat.categories.tolist()
colors = list(sdata.table.uns[label_obs_key + "_colors"])
mypal = dict(zip(cats, colors))
if color_dict is not None:
mypal = color_dict

mypal = {x: mypal[x] for x in group_lst}

# compute values dataframes
Expand All @@ -161,14 +170,14 @@ def plot_shape_along_axis(
}
vals = pd.concat([vals, pd.DataFrame([new_row])], ignore_index=True)

valct = pd.DataFrame({"microns": [], "count": [], "cell_type": []})
valct = pd.DataFrame({"microns": [], "count": [], "celltype": []})
for ct in range(0, len(group_lst)):
df2 = df_celltypes[df_celltypes["ct"] == group_lst[ct]]
for i in range(0, step_number):
new_row = {
"microns": (x_min + (i + 0.5) * bin_size),
"count": df2[(df2.x > (x_min + i * bin_size)) & (df2.x < (x_min + (i + 1) * bin_size))].shape[0],
"cell_type": group_lst[ct],
"celltype": group_lst[ct],
}
valct = pd.concat([valct, pd.DataFrame([new_row])], ignore_index=True)

Expand All @@ -183,7 +192,7 @@ def plot_shape_along_axis(
elements=sdata_polygon_key, color=label_obs_key, palette=list(mypal.values()), groups=group_lst
).pl.show(ax=ax1)

sns.lineplot(data=valct, x="microns", y="count", hue="cell_type", linewidth=0.9, palette=mypal, ax=ax2)
sns.lineplot(data=valct, x="microns", y="count", hue="celltype", linewidth=0.9, palette=mypal, ax=ax2)
ax2.get_legend().remove()

if scale_expr is True:
Expand Down Expand Up @@ -217,6 +226,7 @@ def plot_shape_along_axis(
def plot_sdata(
sdata: sd.SpatialData,
color_key: str = "celltype",
mypal: tuple = None,
):
"""Plot sdata object (i.e. embedding and polygons). This should always works if well synchronized sdata object
Expand All @@ -227,13 +237,57 @@ def plot_sdata(
color_key
color key from .table.obs
"""
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
sc.pl.embedding(sdata.table, "umap", color=color_key, ax=ax1, show=False)
ax1.get_legend().remove()
if sdata.contains_element(sdata.table.uns["spatialdata_attrs"]["region"]):
sdata.pl.render_shapes(elements=sdata.table.uns["spatialdata_attrs"]["region"], color=color_key).pl.show(ax=ax2)
fig, ax = plt.subplots(figsize=(12, 6))
if mypal is not None:
sdata.pl.render_shapes(
elements=sdata.table.uns["spatialdata_attrs"]["region"],
color="celltype",
palette=list(mypal.values()),
groups=list(mypal.keys()),
).pl.show(ax=ax)
else:
sq.pl.spatial_scatter(sdata.table, color=color_key, shape=None, size=1, ax=ax2)
sdata.pl.render_shapes(elements=sdata.table.uns["spatialdata_attrs"]["region"], color="celltype").pl.show(ax=ax)

ax.grid(which="both", linestyle="dashed")
ax.minorticks_on()
ax.tick_params(which="minor", bottom=False, left=False)

legend_without_duplicate_labels(ax)
plt.tight_layout()


def plot_multi_sdata(
sdata: sd.SpatialData,
color_key: str = "celltype",
):
"""Plot sdata object (i.e. embedding and polygons). This should always works if well synchronized sdata object
Parameters
----------
sdata
SpatialData object.
color_key
color key from .table.obs
"""
sdata.table.obs[color_key] = sdata.table.obs[color_key].cat.remove_unused_categories()
sdata.table.uns.pop(color_key + "_colors", None)

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 6))

sns.scatterplot(x="center_x", y="center_y", data=sdata.table.obs, s=1, hue=color_key, ax=ax1)
ax1.axis("equal")
ax1.get_legend().remove()
ax1.set_title("anndata.obs coordinates")
ax1.invert_yaxis()

sdata.pl.render_shapes(elements=sdata.table.uns["spatialdata_attrs"]["region"], color=color_key).pl.show(ax=ax2)
ax2.get_legend().remove()
ax2.set_title("spatialdata polygons")

sq.pl.spatial_scatter(sdata.table, color=color_key, shape=None, size=1, ax=ax3)
# ax2.get_legend().remove()
ax3.set_title("squidpy spatial")

plt.tight_layout()


Expand Down Expand Up @@ -292,32 +346,41 @@ def get_palette(color_key: str) -> dict:
"AlvFibro": "#d58936",
"MyoFibro": "#69140e",
}
elif color_key == "celltype2":
elif color_key == "celltype paolo":
# paolo
palette = {
# paolo
"cartilage": "#005f73",
"myeloid": "#0a9396",
"skeletal muscle": "#94d2bd",
"stromal": "#e9d8a6",
"endothelial": "#64a6bd",
"lymphatic": "#90a8c3",
"pericyte": "#d7b9d5",
"vascular": "#f4cae0",
"basal": "#003049",
"ciliated": "#d62828",
"deuterosomal": "#f77f00",
"secretory": "#fcbf49",
"excitatory neuron": "#797d62",
"neuron": "#4a4e69",
"glia": "#9b9b7a",
"olfactory sensory neuron": "#d9ae94",
"satelite": "#ffcb69",
"sustentacular": "#b58463",
# per compartment
#'cartilage nasal': '#fb8500',
#'vascular lymphatic': '#ef233c',
#'olfactory epithelium': '#344966',
#'migrating neuron': '#606c38',
"Cartilages": "#9DAF07",
"Stromal0": "#99D6A9",
"Stromal1": "#1B8F76",
"Stromal2": "#BBD870",
"Stromal3": "#4CAD4C",
"Lymphatic EC": "#F78896",
"Vascular EC": "#E788C2",
"Pericytes": "#0B4B19",
"Satellites": "#9CBBA6",
"Skeletal muscle": "#9EB3DD", # 3868A6
"Glia progenitors": "#E3D9AC",
"Olfactory ensheathing glia": "#9ecae1", # AE8C0D
"Schwann cells": "#FAF9BA", # C0AC51
"Neurons ALK+": "#7D1C53",
"Respiratory HBCs": "#D50000",
"Olfactory HBCs": "#1A237E",
"Keratinocytes": "#0000FF",
"Duct/MUC": "#BC9EDD",
"Multiciliated": "#FAF204",
"Deuterosomal": "#2EECDB",
"Sustentaculars": "#ff7f00", # E17C10
"GBCs": "#706fd3",
"Early OSNs": "#5AC2BF", # 04C9FA
"Migratory neurons": "#457b9d", ###
"VNO neurons": "#0D16C8",
"Neuron progenitors": "#61559E", # 6A0B78
"Excitatory neurons": "#95819F",
"Inhibitory neurons": "#800EF1",
"GnRH neurons": "#E00EF1",
"Myeloid": "#7E909D", # CB7647
"Microglia": "#91BFB7",
"junk neurons": "#f1faee",
}
elif color_key == "leiden": # default is 40 colors returned
l = list(range(0, 39, 1))
Expand Down Expand Up @@ -404,3 +467,19 @@ def plot_per_groups(adata, clust_key, size=60, is_spatial=False, frameon=False,
legend_loc=legend_loc,
**kwargs,
)


def legend_without_duplicate_labels(figure):
"""Remove duplicated labels in figure legend
Parameters
----------
figure
matplotlib figure.
"""
# code from here
# https://stackoverflow.com/questions/19385639/duplicate-items-in-legend-in-matplotlib

handles, labels = plt.gca().get_legend_handles_labels()
by_label = dict(zip(labels, handles))
figure.legend(by_label.values(), by_label.keys(), loc="center left", bbox_to_anchor=(1.05, 0.5), fontsize=6, ncol=1)
Loading

0 comments on commit 0de348b

Please sign in to comment.