Skip to content

Commit

Permalink
Merge pull request #560 from theislab/feature/regression_classifier
Browse files Browse the repository at this point in the history
Logistic regression support for the Discriminator Classifier
  • Loading branch information
Lilly-May authored Mar 19, 2024
2 parents 3ea2993 + 7fd7bbe commit 80ef0f0
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 32 deletions.
3 changes: 2 additions & 1 deletion docs/usage/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,8 @@ Pertpy offers various modules for calculating and evaluating perturbation spaces
.. autosummary::
:toctree: tools
tools.DiscriminatorClassifierSpace
tools.MLPClassifierSpace
tools.LRClassifierSpace
tools.CentroidSpace
tools.DBSCANSpace
tools.KMeansSpace
Expand Down
6 changes: 5 additions & 1 deletion pertpy/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from pertpy.tools._milo import Milo
from pertpy.tools._mixscape import Mixscape
from pertpy.tools._perturbation_space._clustering import ClusteringSpace
from pertpy.tools._perturbation_space._discriminator_classifier import DiscriminatorClassifierSpace
from pertpy.tools._perturbation_space._discriminator_classifiers import (
DiscriminatorClassifierSpace,
LRClassifierSpace,
MLPClassifierSpace,
)
from pertpy.tools._perturbation_space._simple import CentroidSpace, DBSCANSpace, KMeansSpace, PseudobulkSpace
from pertpy.tools._scgen import SCGEN
2 changes: 1 addition & 1 deletion pertpy/tools/_distances/_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def onesided_distances(
n_jobs: int = -1,
**kwargs,
) -> pd.DataFrame:
"""Get pairwise distances between groups of cells.
"""Get distances between one selected cell group and the remaining other cell groups.
Args:
adata: Annotated data matrix.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING
import warnings
from typing import TYPE_CHECKING, Literal

import anndata
import numpy as np
Expand All @@ -10,6 +11,7 @@
import torch
from anndata import AnnData
from pytorch_lightning.callbacks import EarlyStopping
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from torch import optim
Expand All @@ -18,8 +20,118 @@
from pertpy.tools._perturbation_space._perturbation_space import PerturbationSpace


class DiscriminatorClassifierSpace(PerturbationSpace):
"""Leveraging discriminator classifier. Fit a regressor model to the data and take the feature space.
class LRClassifierSpace(PerturbationSpace):
"""Fits a logistic regression model to the data and takes the feature space as embedding.
We fit one logistic regression model per perturbation. After training, the coefficients of the logistic regression
model are used as the feature space. This results in one embedding per perturbation.
"""

def compute(
self,
adata: AnnData,
target_col: str = "perturbations",
layer_key: str = None,
embedding_key: str = None,
test_split_size: float = 0.2,
max_iter: int = 1000,
):
"""
Fits a logistic regression model to the data and takes the coefficients of the logistic regression
model as perturbation embedding.
Args:
adata: AnnData object of size cells x genes
target_col: .obs column that stores the perturbations. Defaults to "perturbations".
layer_key: Layer in adata to use. Defaults to None.
embedding_key: Key of the embedding in obsm to be used as data for the logistic regression classifier.
Can only be specified if layer_key is None. Defaults to None.
test_split_size: Fraction of data to put in the test set. Default to 0.2.
max_iter: Maximum number of iterations taken for the solvers to converge. Defaults to 1000.
Returns:
AnnData object with the logistic regression coefficients as the embedding in X and the perturbations as .obs['perturbations'].
Examples:
>>> import pertpy as pt
>>> adata = pt.dt.norman_2019()
>>> rcs = pt.tl.LRClassifierSpace()
>>> pert_embeddings = rcs.compute(adata, embedding_key="X_pca", target_col="perturbation_name")
"""
if layer_key is not None and layer_key not in adata.obs.columns:
raise ValueError(f"Layer key {layer_key} not found in adata.")

if embedding_key is not None and embedding_key not in adata.obsm.keys():
raise ValueError(f"Embedding key {embedding_key} not found in adata.obsm.")

if layer_key is not None and embedding_key is not None:
raise ValueError("Cannot specify both layer_key and embedding_key.")

if target_col not in adata.obs:
raise ValueError(f"Column {target_col!r} does not exist in the .obs attribute.")

if layer_key is not None:
regression_data = adata.layers[layer_key]
elif embedding_key is not None:
regression_data = adata.obsm[embedding_key]
else:
regression_data = adata.X

regression_labels = adata.obs[target_col]

# Save adata observations for embedding annotations in get_embeddings
adata_obs = adata.obs.reset_index(drop=True)
adata_obs = adata_obs.groupby(target_col).agg(
lambda pert_group: np.nan if len(set(pert_group)) != 1 else list(set(pert_group))[0]
)

# Fit a logistic regression model for each perturbation
regression_model = LogisticRegression(max_iter=max_iter, class_weight="balanced")
regression_embeddings = {}
regression_scores = {}

for perturbation in regression_labels.unique():
labels = np.where(regression_labels == perturbation, 1, 0)
X_train, X_test, y_train, y_test = train_test_split(
regression_data, labels, test_size=test_split_size, stratify=labels
)

regression_model.fit(X_train, y_train)
regression_embeddings[perturbation] = regression_model.coef_
regression_scores[perturbation] = regression_model.score(X_test, y_test)

# Save the regression embeddings and scores in an AnnData object
pert_adata = AnnData(X=np.array(list(regression_embeddings.values())).squeeze())
pert_adata.obs["perturbations"] = list(regression_embeddings.keys())
pert_adata.obs["classifier_score"] = list(regression_scores.values())

# Save adata observations for embedding annotations
for obs_name in adata_obs.columns:
if not adata_obs[obs_name].isnull().values.any():
pert_adata.obs[obs_name] = pert_adata.obs["perturbations"].map(
{pert: adata_obs.loc[pert][obs_name] for pert in adata_obs.index}
)

return pert_adata


# Ensure backward compatibility with DiscriminatorClassifierSpace
def DiscriminatorClassifierSpace():
warnings.warn(
"The DiscriminatorClassifierSpace class is deprecated and will be removed in the future."
"Please use the MLPClassifierSpace or the LRClassifierSpace class instead.",
DeprecationWarning,
stacklevel=2,
)

return MLPClassifierSpace()


class MLPClassifierSpace(PerturbationSpace):
"""Fits an ANN classifier to the data and takes the feature space (weights in the last layer) as embedding.
We train the ANN to classify the different perturbations. After training, the penultimate layer is used as the
feature space, resulting in one embedding per cell.
See here https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7289078/ (Dose-response analysis) and Sup 17-19.
We use either the coefficients of the model for each perturbation as a feature or train a classifier example
Expand All @@ -38,19 +150,19 @@ def load( # type: ignore
test_split_size: float = 0.2,
validation_split_size: float = 0.25,
):
"""Creates a neural network model using the specified parameters (hidden_dim, dropout, batch_norm). Further
parameters such as the number of classes to predict (number of perturbations) are obtained from the provided
AnnData object directly.
"""Creates a classifier model and dataloaders required for training and testing.
A model is created using the specified parameters (hidden_dim, dropout, batch_norm). Further parameters such as
the number of classes to predict (number of perturbations) are obtained from the provided AnnData object directly.
It further creates dataloaders and fixes class imbalance due to control.
Sets the device to a GPU if available.
Args:
adata: AnnData object of size cells x genes
target_col: .obs column that stores the perturbations. Defaults to "perturbations".
layer_key: Layer in adata to use. Defaults to None.
hidden_dim: list of hidden layers of the neural network. For instance: [512, 256].
dropout: amount of dropout applied, constant for all layers. Defaults to 0.
hidden_dim: List of hidden layers of the neural network. For instance: [512, 256].
dropout: Amount of dropout applied, constant for all layers. Defaults to 0.
batch_norm: Whether to apply batch normalization. Defaults to True.
batch_size: The batch size, i.e. the number of datapoints to use in one forward/backward pass. Defaults to 256.
test_split_size: Fraction of data to put in the test set. Default to 0.2.
Expand All @@ -61,7 +173,7 @@ def load( # type: ignore
Examples:
>>> import pertpy as pt
>>> adata = pt.dt.papalexi_2021()["rna"]
>>> dcs = pt.tl.DiscriminatorClassifierSpace()
>>> dcs = pt.tl.MLPClassifierSpace()
>>> dcs.load(adata, target_col="gene_target")
"""
if layer_key is not None and layer_key not in adata.obs.columns:
Expand Down Expand Up @@ -125,18 +237,19 @@ def load( # type: ignore
return self

def train(self, max_epochs: int = 40, val_epochs_check: int = 5, patience: int = 2):
"""Trains and tests the neural network model defined in the load step.
"""Trains and tests the ANN model defined in the load step.
Args:
max_epochs: max epochs for training. Default to 40.
val_epochs_check: test performance on validation dataset after every val_epochs_check training epochs.
patience: number of validation performance checks without improvement, after which the early stopping flag
is activated and training is therefore stopped.
max_epochs: Maximum number of epochs for training. Defaults to 40.
val_epochs_check: Test performance on validation dataset after every val_epochs_check training epochs.
Defaults to 5.
patience: Number of validation performance checks without improvement, after which the early stopping flag
is activated and training is therefore stopped. Defaults to 2.
Examples:
>>> import pertpy as pt
>>> adata = pt.dt.papalexi_2021()["rna"]
>>> dcs = pt.tl.DiscriminatorClassifierSpace()
>>> dcs = pt.tl.MLPClassifierSpace()
>>> dcs.load(adata, target_col="gene_target")
>>> dcs.train(max_epochs=5)
"""
Expand All @@ -149,31 +262,32 @@ def train(self, max_epochs: int = 40, val_epochs_check: int = 5, patience: int =
accelerator="auto",
)

self.model = PerturbationClassifier(model=self.net, batch_size=self.train_dataloader.batch_size)
self.mlp = PerturbationClassifier(model=self.net, batch_size=self.train_dataloader.batch_size)

self.trainer.fit(
model=self.model, train_dataloaders=self.train_dataloader, val_dataloaders=self.valid_dataloader
)
self.trainer.test(model=self.model, dataloaders=self.test_dataloader)
self.trainer.fit(model=self.mlp, train_dataloaders=self.train_dataloader, val_dataloaders=self.valid_dataloader)
self.trainer.test(model=self.mlp, dataloaders=self.test_dataloader)

def get_embeddings(self) -> AnnData:
"""Obtain the embeddings of the data, i.e., the values in the last layer of the MLP.
"""Obtain the embeddings of the data.
The embeddings correspond to the values in the last layer of the MLP. You will get one embedding per cell,
so be aware that you might need to apply another perturbation space to aggregate the embeddings per perturbation.
Returns:
AnnData whose `X` attribute is the perturbation embedding and whose .obs['perturbations'] are the names of the perturbations.
Examples:
>>> import pertpy as pt
>>> adata = pt.dt.papalexi_2021()["rna"]
>>> dcs = pt.tl.DiscriminatorClassifierSpace()
>>> dcs = pt.tl.MLPClassifierSpace()
>>> dcs.load(adata, target_col="gene_target")
>>> dcs.train()
>>> embeddings = dcs.get_embeddings()
"""
with torch.no_grad():
self.model.eval()
self.mlp.eval()
for dataset_count, batch in enumerate(self.entire_dataset):
emb, y = self.model.get_embeddings(batch)
emb, y = self.mlp.get_embeddings(batch)
emb = torch.squeeze(emb)
batch_adata = AnnData(X=emb.cpu().numpy())
batch_adata.obs["perturbations"] = y
Expand Down
4 changes: 3 additions & 1 deletion pertpy/tools/_perturbation_space/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@ def compute(

if keep_obs: # Save the values of the obs columns of interest in the ps_adata object
obs_df = adata.obs
obs_df = obs_df.groupby(target_col).agg(lambda x: np.nan if len(set(x)) != 1 else list(set(x))[0])
obs_df = obs_df.groupby(target_col).agg(
lambda pert_group: np.nan if len(set(pert_group)) != 1 else list(set(pert_group))[0]
)
for obs_name in obs_df.columns:
if not obs_df[obs_name].isnull().values.any():
mapping = {pert: obs_df.loc[pert][obs_name] for pert in index}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import numpy as np
import pandas as pd
import pertpy as pt
import pytest
from anndata import AnnData


def test_discriminator_classifier():
@pytest.fixture
def adata():
X = np.zeros((20, 5), dtype=np.float32)

pert_index = [
Expand Down Expand Up @@ -42,10 +44,17 @@ def test_discriminator_classifier():

adata = AnnData(X, obs=obs)

# Compute the embeddings using the classifier
ps = pt.tl.DiscriminatorClassifierSpace()
# Add a obs annotations to the adata
adata.obs["MoA"] = ["Growth" if pert == "target1" else "Unknown" for pert in adata.obs["perturbations"]]
adata.obs["Partial Annotation"] = ["Anno1" if pert == "target2" else np.nan for pert in adata.obs["perturbations"]]

return adata


def test_mlp_classifier_space(adata):
ps = pt.tl.MLPClassifierSpace()
classifier_ps = ps.load(adata, hidden_dim=[128])
classifier_ps.train(max_epochs=5)
classifier_ps.train(max_epochs=2)
pert_embeddings = classifier_ps.get_embeddings()

# The embeddings should cluster in 3 perfects clusters since the perturbations are easily separable
Expand All @@ -56,3 +65,19 @@ def test_discriminator_classifier():
np.testing.assert_allclose(results["nmi"], 0.99, rtol=0.1)
np.testing.assert_allclose(results["ari"], 0.99, rtol=0.1)
np.testing.assert_allclose(results["asw"], 0.99, rtol=0.1)


def test_regression_classifier_space(adata):
ps = pt.tl.LRClassifierSpace()
pert_embeddings = ps.compute(adata)

assert pert_embeddings.shape == (3, 5)
assert pert_embeddings.obs[pert_embeddings.obs["perturbations"] == "target1"]["MoA"].values == "Growth"
assert "Partial Annotation" not in pert_embeddings.obs_names
# The classifier should be able to distinguish control and target2 from the respective other two classes
assert np.all(
pert_embeddings.obs[pert_embeddings.obs["perturbations"].isin(["control", "target2"])][
"classifier_score"
].values
== 1.0
)

0 comments on commit 80ef0f0

Please sign in to comment.