From c499e877d0cbedadd24820fa43c45f49a37f604a Mon Sep 17 00:00:00 2001 From: cobioda Date: Wed, 18 Dec 2024 10:31:09 +0100 Subject: [PATCH] v0.1 --- pyproject.toml | 11 ++- src/scispy/pl/basic.py | 1 - src/scispy/pp/basic.py | 14 +-- src/scispy/tl/__init__.py | 2 + src/scispy/tl/basic.py | 196 +++++++++++++++++++++++++++----------- 5 files changed, 154 insertions(+), 70 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6b6260f..eea8434 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,15 +19,16 @@ urls.Documentation = "https://scispy.readthedocs.io/" urls.Source = "https://github.com/cobioda/scispy" urls.Home-page = "https://github.com/cobioda/scispy" dependencies = [ - "scvi-tools", + #"scvi-tools", "pydeseq2", "decoupler", - "spatialdata", - "spatialdata-io", - "spatialdata-plot", + #"spatialdata", + #"spatialdata-io", + #"spatialdata-plot", #"napari-spatialdata", - "squidpy", + #"squidpy", "adjustText", + "statannotations", #"torch==1.13.1", # for debug logging (referenced from the issue template) diff --git a/src/scispy/pl/basic.py b/src/scispy/pl/basic.py index ed59038..8213ee9 100644 --- a/src/scispy/pl/basic.py +++ b/src/scispy/pl/basic.py @@ -398,7 +398,6 @@ def get_palette(color_key: str) -> dict: "B": "#fffbbd", "Megak": "#006400", "Pericyte": "#9c6644", - "Pericytes": "#9c6644", "SMC": "#d81159", "AdvFibro": "#ef6351", "AlvFibro": "#d58936", diff --git a/src/scispy/pp/basic.py b/src/scispy/pp/basic.py index 22bb9d7..223fe14 100644 --- a/src/scispy/pp/basic.py +++ b/src/scispy/pp/basic.py @@ -73,6 +73,7 @@ def scvi_annotate( label_ref: str = "celltype", label_key: str = "celltype", layer: str = "counts", + batch_size: int = 128, metaref2add: tuple = [], filter_under_score: float = 0.5, ): @@ -128,7 +129,7 @@ def scvi_annotate( scvi.model.SCVI.setup_anndata(concat, layer=layer, batch_key="tech") vae = scvi.model.SCVI(concat) # Train the model - vae.train() + vae.train(batch_size=batch_size) # Register the object and run scANVI scvi.model.SCANVI.setup_anndata( @@ -140,7 +141,7 @@ def scvi_annotate( ) lvae = scvi.model.SCANVI.from_scvi_model(vae, labels_key=label_key, unlabeled_category="nan", adata=concat) - lvae.train(max_epochs=20, n_samples_per_label=100) + lvae.train(max_epochs=20, n_samples_per_label=100, batch_size=batch_size) concat.obs["C_scANVI"] = lvae.predict(concat) concat.obsm["X_scANVI"] = lvae.get_latent_representation(concat) @@ -152,7 +153,6 @@ def scvi_annotate( merfish_mask = concat.obs["tech"] == "MERFISH" ad_spatial.obs[f"{label_key}"] = concat.obs["C_scANVI"][merfish_mask].values ad_spatial.obs[f"{label_key}_score"] = concat.obs["score"][merfish_mask].values - ad_spatial.obs[f"{label_key}"] = ad_spatial.obs[f"{label_key}"].astype("category") for i in range(0, len(metaref2add)): @@ -161,10 +161,10 @@ def scvi_annotate( ad_spatial.obs[f"{metaref2add[i]}"] = ad_spatial.obs[f"{metaref2add[i]}"].astype("category") # remove cells having a bad score - nb_cells = ad_spatial.shape[0] - ad_spatial = ad_spatial[ad_spatial.obs[f"{label_key}_score"] >= filter_under_score] - filtered_cells = nb_cells - ad_spatial.shape[0] - print("low assignment score filtering ", filtered_cells) + # nb_cells = ad_spatial.shape[0] + # ad_spatial = ad_spatial[ad_spatial.obs[f"{label_key}_score"] >= filter_under_score] + # filtered_cells = nb_cells - ad_spatial.shape[0] + # print("low assignment score filtering ", filtered_cells) def sync_shape( diff --git a/src/scispy/tl/__init__.py b/src/scispy/tl/__init__.py index 45ce878..4f3d591 100644 --- a/src/scispy/tl/__init__.py +++ b/src/scispy/tl/__init__.py @@ -5,6 +5,7 @@ get_sdata_polygon, prep_pseudobulk, run_pseudobulk, + scis_prop, sdata_querybox, sdata_rotate, ) @@ -18,4 +19,5 @@ "run_pseudobulk", "sdata_rotate", "sdata_querybox", + "scis_prop", ] diff --git a/src/scispy/tl/basic.py b/src/scispy/tl/basic.py index d68cba7..011cd57 100644 --- a/src/scispy/tl/basic.py +++ b/src/scispy/tl/basic.py @@ -18,6 +18,7 @@ from spatialdata import SpatialData from spatialdata.models import PointsModel, ShapesModel from spatialdata.transformations import Affine, Identity, Translation, set_transformation +from statannotations.Annotator import Annotator def add_shapes_from_hdf5( @@ -271,10 +272,10 @@ def prep_pseudobulk( def run_pseudobulk( adata: an.AnnData, - pseudotype_1: str, - pseudotype_2: str, - pseudotype_key: str = "pseudotype", - pseudoname_key: str = "pseudoname", + cond_1: str, + cond_2: str, + cond_key: str, + replicate_key: str, groups_key: str = "celltype", groups: tuple = [], layer: str = "counts", @@ -292,14 +293,14 @@ def run_pseudobulk( ---------- adata AnnData object. - pseudotype_1 - pseudobulk condition (pseudotype) 1 - pseudotype_2 - pseudobulk condition (pseudotype) 2 - pseudotype_key - sdata.table.obs key, i.e. condition - pseudoname_key - sdata.table.obs key, i.e. replicate name + cond_1 + pseudobulk cond_1 + cond_2 + pseudobulk cond_2 + cond_key + sdata.table.obs cond_key + replicate_key + sdata.table.obs replicate_key groups_key sdata.table.obs key, i.e. cell types groups @@ -328,11 +329,11 @@ def run_pseudobulk( # https://decoupler-py.readthedocs.io/en/latest/notebooks/pseudobulk.html # sns.set(font_scale=0.5) - adata = adata[adata.obs[pseudotype_key].isin([pseudotype_1, pseudotype_2])].copy() + adata = adata[adata.obs[cond_key].isin([cond_1, cond_2])].copy() pdata = dc.get_pseudobulk( adata, - sample_col=pseudoname_key, # "pseudoname" + sample_col=replicate_key, # "pseudoname" groups_col=groups_key, # celltype layer=layer, mode="sum", @@ -348,69 +349,64 @@ def run_pseudobulk( for ct in groups: sub = pdata[pdata.obs[groups_key] == ct].copy() - if len(sub.obs[pseudotype_key].to_list()) > 1: + if len(sub.obs[cond_key].to_list()) > 1: # Obtain genes that pass the thresholds - genes = dc.filter_by_expr(sub, group=pseudotype_key, min_count=5, min_total_count=5) + genes = dc.filter_by_expr(sub, group=cond_key, min_count=5, min_total_count=5) # Filter by these genes sub = sub[:, genes].copy() - if len(sub.obs[pseudotype_key].unique().tolist()) > 1: + if len(sub.obs[cond_key].unique().tolist()) > 1: # Build DESeq2 object dds = DeseqDataSet( adata=sub, - design_factors=pseudotype_key, - ref_level=[pseudotype_key, pseudotype_1], + design_factors=cond_key, + ref_level=[cond_key, cond_1], refit_cooks=True, quiet=True, ) - dds.deseq2() - stat_res = DeseqStats(dds, contrast=[pseudotype_key, pseudotype_1, pseudotype_2], quiet=True) + print(len(sub.obs[replicate_key].unique().tolist())) + if len(sub.obs[replicate_key].unique().tolist()) > 2: + dds.deseq2() + stat_res = DeseqStats(dds, contrast=[cond_key, cond_1, cond_2], quiet=True) - stat_res.summary() - coeff_str = pseudotype_key + "_" + pseudotype_2 + "_vs_" + pseudotype_1 - stat_res.lfc_shrink(coeff=coeff_str) + stat_res.summary() + coeff_str = cond_key + "_" + cond_2 + "_vs_" + cond_1 + stat_res.lfc_shrink(coeff=coeff_str) - results_df = stat_res.results_df + results_df = stat_res.results_df - fig, axs = plt.subplots(1, 2, figsize=figsize) - dc.plot_volcano_df(results_df, x="log2FoldChange", y="padj", ax=axs[0], top=20) - axs[0].set_title(ct) + fig, axs = plt.subplots(1, 2, figsize=figsize) + dc.plot_volcano_df(results_df, x="log2FoldChange", y="padj", ax=axs[0], top=20) + axs[0].set_title(ct) - # sign_thr=0.05, lFCs_thr=0.5 - results_df["pvals"] = -np.log10(results_df["padj"]) + # sign_thr=0.05, lFCs_thr=0.5 + results_df["pvals"] = -np.log10(results_df["padj"]) - up_msk = (results_df["log2FoldChange"] >= lFCs_thr) & (results_df["pvals"] >= -np.log10(sign_thr)) - dw_msk = (results_df["log2FoldChange"] <= -lFCs_thr) & (results_df["pvals"] >= -np.log10(sign_thr)) - signs = results_df[up_msk | dw_msk].sort_values("pvals", ascending=False) - signs = signs.iloc[:20] - signs = signs.sort_values("log2FoldChange", ascending=False) + up_msk = (results_df["log2FoldChange"] >= lFCs_thr) & (results_df["pvals"] >= -np.log10(sign_thr)) + dw_msk = (results_df["log2FoldChange"] <= -lFCs_thr) & (results_df["pvals"] >= -np.log10(sign_thr)) + signs = results_df[up_msk | dw_msk].sort_values("pvals", ascending=False) + signs = signs.iloc[:20] + signs = signs.sort_values("log2FoldChange", ascending=False) - # concatenate to total - signs[groups_key] = ct - df_total = pd.concat([df_total, signs.reset_index()]) + # concatenate to total + signs[groups_key] = ct + df_total = pd.concat([df_total, signs.reset_index()]) - if len(signs.index.tolist()) > 0: - sc.pp.normalize_total(sub) - sc.pp.log1p(sub) - sc.pp.scale(sub, max_value=10) - sc.pl.matrixplot(sub, signs.index, groupby=pseudoname_key, ax=axs[1]) + if len(signs.index.tolist()) > 0: + sc.pp.normalize_total(sub) + sc.pp.log1p(sub) + sc.pp.scale(sub, max_value=10) + sc.pl.matrixplot(sub, signs.index, groupby=replicate_key, ax=axs[1]) - plt.tight_layout() + plt.tight_layout() - if save is True: - results_df.to_csv(save_prefix + "_" + ct + ".csv") - fig.savefig(save_prefix + "_" + ct + ".pdf", bbox_inches="tight") + if save is True: + results_df.to_csv(save_prefix + "_" + ct + ".csv") + fig.savefig(save_prefix + "_" + ct + ".pdf", bbox_inches="tight") - if len(df_total[groups_key].unique()) > 2: - pivlfc = pd.pivot_table( - df_total, values=["log2FoldChange"], index=["index"], columns=[groups_key], fill_value=0 - ) - # pd.pivot_table(df_total, values=["pvals"], index=["index"], columns=[groups_col], fill_value=0) - # ## plot pivot table as heatmap using seaborn - sns.clustermap(pivlfc, cmap="vlag", figsize=(6, 6)) - # ## plt.setp( ax.xaxis.get_majorticklabels(), rotation=90) - # # plt.tight_layout() - # # plt.show() + if not df_total.empty: + if len(df_total[groups_key].unique()) > 2: + pd.pivot_table(df_total, values=["log2FoldChange"], index=["index"], columns=[groups_key], fill_value=0) return df_total @@ -576,3 +572,89 @@ def sdata_querybox( else: return sdata_crop + + +def scis_prop( + adata: an.AnnData, + celltype: str = "ct_musk", + zone: str = "cc_niches", + replicate: str = "sample", + condition: str = "group", + condition_order: tuple = ["CTRL", "PAH"], # might be possible to provide more conditions + top: int = 5, + figsize: tuple = (6, 3), +): + """Compute per zone celltype proportion between 2 conditions using replicate for statistical testing + + Parameters + ---------- + adata + AnnData object. + celltype + celltype key in adata.obs + zone + zone key in adata.obs + replicate + replicate key in adata.obs + condition + condition key in adata.obs + condition_order + tuple of the 2 conditions to test + top + top celltype to consider + figsize + figure size + Returns + ------- + + """ + sns.set_theme(style="whitegrid", palette="pastel") + l = list(adata.obs[zone].unique()) + for n in l: + # print(n) + df = adata[adata.obs[zone] == n].obs[[replicate, condition, celltype]] + df2 = df.groupby([replicate, condition, celltype])[celltype].count().unstack() + df2 = df2.div(df2.sum(axis=1), axis=0).reset_index() + df2 = df2.melt(id_vars=[replicate, condition]) + df2 = df2.dropna() + df2 = df2[df2.value > 0] + + hits = list(df[celltype].value_counts().head(top).keys()) + df2 = df2[df2[celltype].isin(hits)] + + pairs = [] + for ct in hits: + if len(df2[df2[celltype] == ct][condition].unique()) > 1: + pairs.append([(ct, condition_order[0]), (ct, condition_order[1])]) + + subcat_order = hits + hue_plot_params = { + "data": df2, + "x": celltype, + "y": "value", + "order": subcat_order, + "hue": "group", + "hue_order": condition_order, + # "palette": pal_group, + } + + if len(pairs) > 0: + fig, ax = plt.subplots(1, 1, figsize=figsize) + sns.boxplot(ax=ax, **hue_plot_params, boxprops={"alpha": 0.8}, showfliers=False, linewidth=0.5) + sns.stripplot(ax=ax, **hue_plot_params, dodge=True, edgecolor="black", linewidth=0.5, size=3) + + annotator = Annotator(ax, pairs, **hue_plot_params) + annotator.configure(test="Mann-Whitney", text_format="star") + annotator.apply_and_annotate() + + # When creating the legend, only use the first 2 elements + handles, labels = ax.get_legend_handles_labels() + l = plt.legend(handles[0:2], labels[0:2], bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0) + + ax.set_xticklabels(ax.get_xticklabels(), rotation=90, size=6) + ax.set_yticklabels(ax.get_yticklabels(), size=6) + ax.xaxis.grid(True) + ax.yaxis.grid(True) + ax.set(ylabel="") + ax.set_title("zone " + str(n)) + plt.tight_layout()