Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of CINEMA-OT for pertpy #377

Merged
merged 9 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions docs/usage/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,48 @@ pred.obs["condition"] = "pred"

See [scgen tutorial](https://pertpy.readthedocs.io/en/latest/tutorials/notebooks/scgen_perturbation_prediction.html) for a more elaborate tutorial.

#### CINEMA-OT

An implementation of [CINEMA-OT](https://github.com/vandijklab/CINEMA-OT) with the ott-jax library.
CINEMA-OT is a causal framework for perturbation effect analysis to identify individual treatment effects and synergy at the single cell level.
CINEMA-OT separates confounding sources of variation from perturbation effects to obtain an optimal transport matching that reflects counterfactual cell pairs.
These cell pairs represent causal perturbation responses permitting a number of novel analyses, such as individual treatment effect analysis, response clustering, attribution analysis, and synergy analysis.
See [Causal identification of single-cell experimental perturbation effects with CINEMA-OT](https://www.biorxiv.org/content/10.1101/2022.07.31.502173v3.abstract) for more details.

```{eval-rst}
.. currentmodule:: pertpy
```

```{eval-rst}
.. autosummary::
:toctree: tools

tools.Cinemaot
```

Example implementation:

```python
import pertpy as pt

adata = pt.dt.cinemaot_example()

model = pt.tl.Cinemaot()
de = model.causaleffect(
adata,
pert_key="perturbation",
control="No stimulation",
return_matching=True,
thres=0.5,
smoothness=1e-5,
eps=1e-3,
solver="Sinkhorn",
preweight_label="cell_type0528",
)
```

See [CINEMA-OT tutorial](https://pertpy.readthedocs.io/en/latest/tutorials/notebooks/cinemaot.html) for a more elaborate tutorial.

### Perturbation space

Various modules for calculating and evaluating perturbation spaces.
Expand Down Expand Up @@ -559,3 +601,12 @@ See [perturbation space tutorial](https://pertpy.readthedocs.io/en/latest/tutori
plot.scg.reg_var_plot
plot.scg.binary_classifier
```

#### CINEMA-OT

```{eval-rst}
.. autosummary::
:toctree: plot

plot.cot.vis_matching
```
1 change: 1 addition & 0 deletions pertpy/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
bhattacherjee,
burczynski_crohn,
chang_2021,
cinemaot_example,
datlinger_2017,
datlinger_2021,
dialogue_example,
Expand Down
22 changes: 22 additions & 0 deletions pertpy/data/_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1407,3 +1407,25 @@ def zhao_2021() -> AnnData: # pragma: no cover
adata = sc.read_h5ad(output_file_path)

return adata


def cinemaot_example() -> AnnData: # pragma: no cover:
"""CINEMA-OT Example dataset.

Ex vivo stimulation of human peripheral blood mononuclear cells (PBMC) with interferon.

Returns:
:class:`~anndata.AnnData` object of PBMCs stimulated with interferon.
"""
output_file_name = "cinemaot_example.h5ad"
output_file_path = settings.datasetdir.__str__() + "/" + output_file_name
if not Path(output_file_path).exists():
_download(
url="https://figshare.com/ndownloader/files/42362796?private_link=270b0d2c7f1ea57c366d",
output_file_name=output_file_name,
output_path=settings.datasetdir,
is_zip=False,
)
adata = sc.read_h5ad(output_file_path)

return adata
1 change: 1 addition & 0 deletions pertpy/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
except ImportError:
pass

from pertpy.plot._cinemaot import CinemaotPlot as cot
from pertpy.plot._guide_rna import GuideRnaPlot as guide
from pertpy.plot._milopy import MilopyPlot as milo
from pertpy.plot._mixscape import MixscapePlot as ms
Expand Down
81 changes: 81 additions & 0 deletions pertpy/plot/_cinemaot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from typing import Optional

import matplotlib.pyplot as plt
import pandas as pd
import scanpy as sc
import seaborn as sns
from anndata import AnnData
from matplotlib.axes import Axes
from scanpy.plotting import _utils


class CinemaotPlot:
"""Plotting functions for CINEMA-OT. Only includes new functions beyond the scanpy.pl.embedding family."""

@staticmethod
def vis_matching(
adata: AnnData,
de: AnnData,
pert_key: str,
control: str,
de_label: str,
source_label: str,
matching_rep: str = "ot",
resolution: float = 0.5,
normalize: str = "col",
title: str = "CINEMA-OT matching matrix",
min_val: float = 0.01,
show: bool = True,
save: Optional[str] = None,
ax: Optional[Axes] = None,
**kwargs,
) -> None:
"""Visualize the CINEMA-OT matching matrix.

Args:
adata: the original anndata after running cinemaot.causaleffect or cinemaot.causaleffect_weighted.
de: The anndata output from Cinemaot.causaleffect() or Cinemaot.causaleffect_weighted().
pert_key: The column of `.obs` with perturbation categories, should also contain `control`.
control: Control category from the `pert_key` column.
de_label: the label for differential response. If none, use leiden cluster labels at resolution 1.0.
source_label: the confounder / cell type label.
matching_rep: the place that stores the matching matrix. default de.obsm['ot'].
normalize: normalize the coarse-grained matching matrix by row / column.
title: the title for the figure.
min_val: The min value to truncate the matching matrix.
show: Show the plot, do not return axis.
save: If `True` or a `str`, save the figure. A string is appended to the default filename.
Infer the filetype if ending on {`'.pdf'`, `'.png'`, `'.svg'`}.
**kwargs: Other parameters to input for seaborn.heatmap.
"""
adata_ = adata[adata.obs[pert_key] == control]

df = pd.DataFrame(de.obsm[matching_rep])
if de_label is None:
de_label = "leiden"
sc.pp.neighbors(de, use_rep="X_embedding")
sc.tl.leiden(de, resolution=resolution)
df["de_label"] = de.obs[de_label].astype(str).values
df["de_label"] = "Response " + df["de_label"]
df = df.groupby("de_label").sum().T
df["source_label"] = adata_.obs[source_label].astype(str).values
df = df.groupby("source_label").sum()

if normalize == "col":
df = df / df.sum(axis=0)
else:
df = (df.T / df.sum(axis=1)).T
df = df.clip(lower=min_val) - min_val
if normalize == "col":
df = df / df.sum(axis=0)
else:
df = (df.T / df.sum(axis=1)).T

g = sns.heatmap(df, annot=True, ax=ax, **kwargs)
plt.title(title)
_utils.savefig_or_show("matching_heatmap", show=show, save=save)
if not show:
if ax is not None:
return ax
else:
return g
1 change: 1 addition & 0 deletions pertpy/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from rich import print

from pertpy.tools._augur import Augur
from pertpy.tools._cinemaot import Cinemaot
from pertpy.tools._dialogue import Dialogue
from pertpy.tools._differential_gene_expression import DifferentialGeneExpression
from pertpy.tools._distances._distance_tests import DistanceTest
Expand Down
Loading
Loading