Merge pull request #4 from cobioda/bego

v0.1
cobioda authored Dec 18, 2024
2 parents 71dc315 + 00ebc36 commit 98e031c
Showing 5 changed files with 150 additions and 66 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -19,14 +19,15 @@ urls.Documentation = "https://scispy.readthedocs.io/"
 urls.Source = "https://github.com/cobioda/scispy"
 urls.Home-page = "https://github.com/cobioda/scispy"
 dependencies = [
-    "scvi-tools",
+    #"scvi-tools",
     "pydeseq2",
     "decoupler",
    "squidpy",
     "spatialdata",
     "spatialdata-io",
     "spatialdata-plot",
     #"napari-spatialdata",
+    "statannotations",
     #"torch==1.13.1",
     #"adjustText",
 
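Note: scvi is still imported and used by scvi_annotate in src/scispy/pp/basic.py (see below), so with "scvi-tools" commented out of the dependency list that function presumably now requires a manual `pip install scvi-tools`.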
1 change: 0 additions & 1 deletion src/scispy/pl/basic.py
@@ -398,7 +398,6 @@ def get_palette(color_key: str) -> dict:
         "B": "#fffbbd",
         "Megak": "#006400",
         "Pericyte": "#9c6644",
-        "Pericytes": "#9c6644",
         "SMC": "#d81159",
         "AdvFibro": "#ef6351",
         "AlvFibro": "#d58936",
14 changes: 7 additions & 7 deletions src/scispy/pp/basic.py
@@ -73,6 +73,7 @@ def scvi_annotate(
     label_ref: str = "celltype",
     label_key: str = "celltype",
     layer: str = "counts",
+    batch_size: int = 128,
     metaref2add: tuple = [],
     filter_under_score: float = 0.5,
 ):
@@ -128,7 +129,7 @@ def scvi_annotate(
     scvi.model.SCVI.setup_anndata(concat, layer=layer, batch_key="tech")
     vae = scvi.model.SCVI(concat)
     # Train the model
-    vae.train()
+    vae.train(batch_size=batch_size)

     # Register the object and run scANVI
     scvi.model.SCANVI.setup_anndata(
@@ -140,7 +141,7 @@ def scvi_annotate(
     )
 
     lvae = scvi.model.SCANVI.from_scvi_model(vae, labels_key=label_key, unlabeled_category="nan", adata=concat)
-    lvae.train(max_epochs=20, n_samples_per_label=100)
+    lvae.train(max_epochs=20, n_samples_per_label=100, batch_size=batch_size)

     concat.obs["C_scANVI"] = lvae.predict(concat)
     concat.obsm["X_scANVI"] = lvae.get_latent_representation(concat)
@@ -152,7 +153,6 @@ def scvi_annotate(
     merfish_mask = concat.obs["tech"] == "MERFISH"
     ad_spatial.obs[f"{label_key}"] = concat.obs["C_scANVI"][merfish_mask].values
     ad_spatial.obs[f"{label_key}_score"] = concat.obs["score"][merfish_mask].values
-
     ad_spatial.obs[f"{label_key}"] = ad_spatial.obs[f"{label_key}"].astype("category")

     for i in range(0, len(metaref2add)):
@@ -161,10 +161,10 @@ def scvi_annotate(
         ad_spatial.obs[f"{metaref2add[i]}"] = ad_spatial.obs[f"{metaref2add[i]}"].astype("category")

     # remove cells having a bad score
-    nb_cells = ad_spatial.shape[0]
-    ad_spatial = ad_spatial[ad_spatial.obs[f"{label_key}_score"] >= filter_under_score]
-    filtered_cells = nb_cells - ad_spatial.shape[0]
-    print("low assignment score filtering ", filtered_cells)
+    # nb_cells = ad_spatial.shape[0]
+    # ad_spatial = ad_spatial[ad_spatial.obs[f"{label_key}_score"] >= filter_under_score]
+    # filtered_cells = nb_cells - ad_spatial.shape[0]
+    # print("low assignment score filtering ", filtered_cells)
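For context, a minimal usage sketch of the updated function (not part of this commit). The two leading positional arguments (spatial query and scRNA-seq reference AnnData) and the scispy.pp export path are assumed from the surrounding code; file names are hypothetical:

    import scanpy as sc
    import scispy

    # hypothetical inputs: a MERFISH AnnData and an annotated scRNA-seq reference
    ad_spatial = sc.read_h5ad("merfish.h5ad")
    ad_ref = sc.read_h5ad("reference.h5ad")

    scispy.pp.scvi_annotate(
        ad_spatial,
        ad_ref,
        label_ref="celltype",
        label_key="celltype",
        layer="counts",
        batch_size=256,  # new in this commit; forwarded to both SCVI and scANVI training
    )

Note that the low-score filtering is commented out above, so filter_under_score currently has no effect.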


def sync_shape(
2 changes: 2 additions & 0 deletions src/scispy/tl/__init__.py
@@ -5,6 +5,7 @@
     get_sdata_polygon,
     prep_pseudobulk,
     run_pseudobulk,
+    scis_prop,
     sdata_querybox,
     sdata_rotate,
 )
@@ -23,5 +24,6 @@
     "run_pseudobulk",
     "sdata_rotate",
     "sdata_querybox",
+    "scis_prop",
     "shapes_of_cell_type",
 ]
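With these two additions the new function becomes part of the public API, so it should be importable as:

    from scispy.tl import scis_prop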
196 changes: 139 additions & 57 deletions src/scispy/tl/basic.py
@@ -18,6 +18,7 @@
 from spatialdata import SpatialData
 from spatialdata.models import PointsModel, ShapesModel
 from spatialdata.transformations import Affine, Identity, Translation, set_transformation
+from statannotations.Annotator import Annotator


def add_shapes_from_hdf5(
@@ -271,10 +272,10 @@ def prep_pseudobulk(

 def run_pseudobulk(
     adata: an.AnnData,
-    pseudotype_1: str,
-    pseudotype_2: str,
-    pseudotype_key: str = "pseudotype",
-    pseudoname_key: str = "pseudoname",
+    cond_1: str,
+    cond_2: str,
+    cond_key: str,
+    replicate_key: str,
     groups_key: str = "celltype",
     groups: tuple = [],
     layer: str = "counts",
@@ -292,14 +293,14 @@ def run_pseudobulk(
     ----------
     adata
         AnnData object.
-    pseudotype_1
-        pseudobulk condition (pseudotype) 1
-    pseudotype_2
-        pseudobulk condition (pseudotype) 2
-    pseudotype_key
-        sdata.table.obs key, i.e. condition
-    pseudoname_key
-        sdata.table.obs key, i.e. replicate name
+    cond_1
+        first pseudobulk condition (reference level)
+    cond_2
+        second pseudobulk condition
+    cond_key
+        adata.obs key holding the condition
+    replicate_key
+        adata.obs key holding the replicate name
     groups_key
         sdata.table.obs key, i.e. cell types
     groups
@@ -328,11 +329,11 @@ def run_pseudobulk(
     # https://decoupler-py.readthedocs.io/en/latest/notebooks/pseudobulk.html
     # sns.set(font_scale=0.5)
 
-    adata = adata[adata.obs[pseudotype_key].isin([pseudotype_1, pseudotype_2])].copy()
+    adata = adata[adata.obs[cond_key].isin([cond_1, cond_2])].copy()
 
     pdata = dc.get_pseudobulk(
         adata,
-        sample_col=pseudoname_key,  # "pseudoname"
+        sample_col=replicate_key,  # "pseudoname"
         groups_col=groups_key,  # celltype
         layer=layer,
         mode="sum",
@@ -348,69 +349,64 @@ def run_pseudobulk(
     for ct in groups:
         sub = pdata[pdata.obs[groups_key] == ct].copy()
 
-        if len(sub.obs[pseudotype_key].to_list()) > 1:
+        if len(sub.obs[cond_key].to_list()) > 1:
             # Obtain genes that pass the thresholds
-            genes = dc.filter_by_expr(sub, group=pseudotype_key, min_count=5, min_total_count=5)
+            genes = dc.filter_by_expr(sub, group=cond_key, min_count=5, min_total_count=5)
             # Filter by these genes
             sub = sub[:, genes].copy()
 
-            if len(sub.obs[pseudotype_key].unique().tolist()) > 1:
+            if len(sub.obs[cond_key].unique().tolist()) > 1:
                 # Build DESeq2 object
                 dds = DeseqDataSet(
                     adata=sub,
-                    design_factors=pseudotype_key,
-                    ref_level=[pseudotype_key, pseudotype_1],
+                    design_factors=cond_key,
+                    ref_level=[cond_key, cond_1],
                     refit_cooks=True,
                     quiet=True,
                 )
-                dds.deseq2()
-                stat_res = DeseqStats(dds, contrast=[pseudotype_key, pseudotype_1, pseudotype_2], quiet=True)
+                print(len(sub.obs[replicate_key].unique().tolist()))
+                if len(sub.obs[replicate_key].unique().tolist()) > 2:
+                    dds.deseq2()
+                    stat_res = DeseqStats(dds, contrast=[cond_key, cond_1, cond_2], quiet=True)
 
-                stat_res.summary()
-                coeff_str = pseudotype_key + "_" + pseudotype_2 + "_vs_" + pseudotype_1
-                stat_res.lfc_shrink(coeff=coeff_str)
+                    stat_res.summary()
+                    coeff_str = cond_key + "_" + cond_2 + "_vs_" + cond_1
+                    stat_res.lfc_shrink(coeff=coeff_str)
 
-                results_df = stat_res.results_df
+                    results_df = stat_res.results_df
 
-                fig, axs = plt.subplots(1, 2, figsize=figsize)
-                dc.plot_volcano_df(results_df, x="log2FoldChange", y="padj", ax=axs[0], top=20)
-                axs[0].set_title(ct)
+                    fig, axs = plt.subplots(1, 2, figsize=figsize)
+                    dc.plot_volcano_df(results_df, x="log2FoldChange", y="padj", ax=axs[0], top=20)
+                    axs[0].set_title(ct)
 
-                # sign_thr=0.05, lFCs_thr=0.5
-                results_df["pvals"] = -np.log10(results_df["padj"])
+                    # sign_thr=0.05, lFCs_thr=0.5
+                    results_df["pvals"] = -np.log10(results_df["padj"])
 
-                up_msk = (results_df["log2FoldChange"] >= lFCs_thr) & (results_df["pvals"] >= -np.log10(sign_thr))
-                dw_msk = (results_df["log2FoldChange"] <= -lFCs_thr) & (results_df["pvals"] >= -np.log10(sign_thr))
-                signs = results_df[up_msk | dw_msk].sort_values("pvals", ascending=False)
-                signs = signs.iloc[:20]
-                signs = signs.sort_values("log2FoldChange", ascending=False)
+                    up_msk = (results_df["log2FoldChange"] >= lFCs_thr) & (results_df["pvals"] >= -np.log10(sign_thr))
+                    dw_msk = (results_df["log2FoldChange"] <= -lFCs_thr) & (results_df["pvals"] >= -np.log10(sign_thr))
+                    signs = results_df[up_msk | dw_msk].sort_values("pvals", ascending=False)
+                    signs = signs.iloc[:20]
+                    signs = signs.sort_values("log2FoldChange", ascending=False)
 
-                # concatenate to total
-                signs[groups_key] = ct
-                df_total = pd.concat([df_total, signs.reset_index()])
+                    # concatenate to total
+                    signs[groups_key] = ct
+                    df_total = pd.concat([df_total, signs.reset_index()])
 
-                if len(signs.index.tolist()) > 0:
-                    sc.pp.normalize_total(sub)
-                    sc.pp.log1p(sub)
-                    sc.pp.scale(sub, max_value=10)
-                    sc.pl.matrixplot(sub, signs.index, groupby=pseudoname_key, ax=axs[1])
+                    if len(signs.index.tolist()) > 0:
+                        sc.pp.normalize_total(sub)
+                        sc.pp.log1p(sub)
+                        sc.pp.scale(sub, max_value=10)
+                        sc.pl.matrixplot(sub, signs.index, groupby=replicate_key, ax=axs[1])
 
-                plt.tight_layout()
+                    plt.tight_layout()
 
-                if save is True:
-                    results_df.to_csv(save_prefix + "_" + ct + ".csv")
-                    fig.savefig(save_prefix + "_" + ct + ".pdf", bbox_inches="tight")
+                    if save is True:
+                        results_df.to_csv(save_prefix + "_" + ct + ".csv")
+                        fig.savefig(save_prefix + "_" + ct + ".pdf", bbox_inches="tight")
 
-    if len(df_total[groups_key].unique()) > 2:
-        pivlfc = pd.pivot_table(
-            df_total, values=["log2FoldChange"], index=["index"], columns=[groups_key], fill_value=0
-        )
-        # pd.pivot_table(df_total, values=["pvals"], index=["index"], columns=[groups_col], fill_value=0)
-        # ## plot pivot table as heatmap using seaborn
-        sns.clustermap(pivlfc, cmap="vlag", figsize=(6, 6))
-        # ## plt.setp( ax.xaxis.get_majorticklabels(), rotation=90)
-        # # plt.tight_layout()
-        # # plt.show()
+    if not df_total.empty:
+        if len(df_total[groups_key].unique()) > 2:
+            pd.pivot_table(df_total, values=["log2FoldChange"], index=["index"], columns=[groups_key], fill_value=0)
 
     return df_total
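For context, a hedged usage sketch of the renamed interface (not part of this commit; the obs column names, condition labels and cell types are hypothetical, and scispy.tl is assumed as the export path):

    import scanpy as sc
    import scispy

    adata = sc.read_h5ad("cohort.h5ad")  # hypothetical input with a "counts" layer

    df_total = scispy.tl.run_pseudobulk(
        adata,
        cond_1="CTRL",  # reference level passed to DESeq2
        cond_2="PAH",
        cond_key="condition",  # adata.obs column holding the condition
        replicate_key="sample",  # adata.obs column holding the replicate
        groups_key="celltype",
        groups=["AT1", "AT2"],  # cell types to test
    )

With the new guard, DESeq2 only runs for a cell type when more than two replicates are present; cell types with fewer replicates contribute nothing to the returned df_total.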

@@ -576,3 +572,89 @@ def sdata_querybox(

     else:
         return sdata_crop


+def scis_prop(
+    adata: an.AnnData,
+    celltype: str = "ct_musk",
+    zone: str = "cc_niches",
+    replicate: str = "sample",
+    condition: str = "group",
+    condition_order: tuple = ["CTRL", "PAH"],  # might be possible to provide more conditions
+    top: int = 5,
+    figsize: tuple = (6, 3),
+):
+    """Compute per-zone cell type proportions between two conditions, using replicates for statistical testing.
+
+    Parameters
+    ----------
+    adata
+        AnnData object.
+    celltype
+        cell type key in adata.obs
+    zone
+        zone key in adata.obs
+    replicate
+        replicate key in adata.obs
+    condition
+        condition key in adata.obs
+    condition_order
+        tuple of the 2 conditions to test
+    top
+        number of most abundant cell types to consider
+    figsize
+        figure size
+    """
+    sns.set_theme(style="whitegrid", palette="pastel")
+    zones = list(adata.obs[zone].unique())
+    for n in zones:
+        # print(n)
+        df = adata[adata.obs[zone] == n].obs[[replicate, condition, celltype]]
+        df2 = df.groupby([replicate, condition, celltype])[celltype].count().unstack()
+        df2 = df2.div(df2.sum(axis=1), axis=0).reset_index()
+        df2 = df2.melt(id_vars=[replicate, condition])
+        df2 = df2.dropna()
+        df2 = df2[df2.value > 0]
+
+        hits = list(df[celltype].value_counts().head(top).keys())
+        df2 = df2[df2[celltype].isin(hits)]
+
+        pairs = []
+        for ct in hits:
+            if len(df2[df2[celltype] == ct][condition].unique()) > 1:
+                pairs.append([(ct, condition_order[0]), (ct, condition_order[1])])
+
+        subcat_order = hits
+        hue_plot_params = {
+            "data": df2,
+            "x": celltype,
+            "y": "value",
+            "order": subcat_order,
+            "hue": condition,
+            "hue_order": condition_order,
+            # "palette": pal_group,
+        }
+
+        if len(pairs) > 0:
+            fig, ax = plt.subplots(1, 1, figsize=figsize)
+            sns.boxplot(ax=ax, **hue_plot_params, boxprops={"alpha": 0.8}, showfliers=False, linewidth=0.5)
+            sns.stripplot(ax=ax, **hue_plot_params, dodge=True, edgecolor="black", linewidth=0.5, size=3)
+
+            annotator = Annotator(ax, pairs, **hue_plot_params)
+            annotator.configure(test="Mann-Whitney", text_format="star")
+            annotator.apply_and_annotate()
+
+            # When creating the legend, only use the first 2 elements
+            handles, labels = ax.get_legend_handles_labels()
+            plt.legend(handles[0:2], labels[0:2], bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)
+
+            ax.set_xticklabels(ax.get_xticklabels(), rotation=90, size=6)
+            ax.set_yticklabels(ax.get_yticklabels(), size=6)
+            ax.xaxis.grid(True)
+            ax.yaxis.grid(True)
+            ax.set(ylabel="")
+            ax.set_title("zone " + str(n))
+            plt.tight_layout()
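A hedged usage sketch (not part of this commit): with the defaults above, adata.obs is expected to carry "ct_musk", "cc_niches", "sample" and "group" columns, spelled out explicitly here.

    import scanpy as sc
    import scispy

    adata = sc.read_h5ad("niches.h5ad")  # hypothetical annotated input

    scispy.tl.scis_prop(
        adata,
        celltype="ct_musk",
        zone="cc_niches",
        replicate="sample",
        condition="group",
        condition_order=["CTRL", "PAH"],
        top=5,  # plot only the 5 most abundant cell types per zone
    )

One boxplot/stripplot panel is drawn per zone, with Mann-Whitney tests (via statannotations) comparing the two conditions within each cell type.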
