Backport PR #2063: Add option for a linear classifier in scANVI (#2084)
Co-authored-by: Martin Kim <[email protected]>
meeseeksmachine and martinkim0 authored Jun 2, 2023
1 parent 50b4da8 commit e906264
Showing 5 changed files with 58 additions and 18 deletions.
1 change: 1 addition & 0 deletions docs/release_notes/index.md
@@ -35,6 +35,7 @@ is available in the [commit logs](https://github.com/YosefLab/scvi-tools/commits
- Raise warning if MPS backend is selected for PyTorch, see https://github.com/pytorch/pytorch/issues/77764 {pr}`2045`.
- Add `deregister_manager` function to {class}`scvi.model.base.BaseModelClass`, allowing to clear
{class}`scvi.data.AnnDataManager` instances from memory {pr}`2060`.
- Add option to use a linear classifier in {class}`scvi.model.SCANVI` {pr}`2063`.
- Add lower bound 0.12.1 for Numpyro dependency {pr}`2078`.
- Add new section in scBasset tutorial for motif scoring {pr}`2079`.

5 changes: 5 additions & 0 deletions scvi/model/_scanvi.py
@@ -77,6 +77,9 @@ class SCANVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMinifiedModeModelClass):
        * ``'nb'`` - Negative binomial distribution
        * ``'zinb'`` - Zero-inflated negative binomial distribution
        * ``'poisson'`` - Poisson distribution
    linear_classifier
        If `True`, uses a single linear layer for classification instead of a
        multi-layer perceptron.
    **model_kwargs
        Keyword args for :class:`~scvi.module.SCANVAE`
@@ -110,6 +113,7 @@ def __init__(
        dropout_rate: float = 0.1,
        dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
        gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb",
        linear_classifier: bool = False,
        **model_kwargs,
    ):
        super().__init__(adata)
@@ -152,6 +156,7 @@ def __init__(
            use_size_factor_key=use_size_factor_key,
            library_log_means=library_log_means,
            library_log_vars=library_log_vars,
            linear_classifier=linear_classifier,
            **scanvae_model_kwargs,
        )
        self.module.minified_data_type = self.minified_data_type
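
For reference, a minimal usage sketch of the new option (not part of the diff; the toy `synthetic_iid` dataset and setup keys mirror the test added below in this commit, and any AnnData registered for SCANVI works the same way):

```python
import scvi
from scvi.model import SCANVI

# Toy dataset bundled with scvi-tools; "label_0" serves as the unlabeled category.
adata = scvi.data.synthetic_iid(n_labels=5)
SCANVI.setup_anndata(
    adata,
    labels_key="labels",
    unlabeled_category="label_0",
    batch_key="batch",
)

# linear_classifier=True replaces the MLP classification head with a single
# linear layer; everything else about SCANVI is unchanged.
model = SCANVI(adata, linear_classifier=True)
model.train(max_epochs=1)
```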
41 changes: 25 additions & 16 deletions scvi/module/_classifier.py
@@ -11,11 +11,13 @@ class Classifier(nn.Module):
    n_input
        Number of input dimensions
    n_hidden
        Number of hidden nodes in hidden layer
        Number of nodes in hidden layer(s). If `0`, the classifier only consists of a
        single linear layer.
    n_labels
        Number of output dimensions
    n_layers
        Number of hidden layers
        Number of hidden layers. If `0`, the classifier only consists of a single
        linear layer.
    dropout_rate
        Dropout rate for nodes
    logits
@@ -45,20 +47,27 @@
    ):
        super().__init__()
        self.logits = logits
        layers = [
            FCLayers(
                n_in=n_input,
                n_out=n_hidden,
                n_layers=n_layers,
                n_hidden=n_hidden,
                dropout_rate=dropout_rate,
                use_batch_norm=use_batch_norm,
                use_layer_norm=use_layer_norm,
                activation_fn=activation_fn,
                **kwargs,
            ),
            nn.Linear(n_hidden, n_labels),
        ]
        layers = []

        if n_hidden > 0 and n_layers > 0:
            layers.append(
                FCLayers(
                    n_in=n_input,
                    n_out=n_hidden,
                    n_layers=n_layers,
                    n_hidden=n_hidden,
                    dropout_rate=dropout_rate,
                    use_batch_norm=use_batch_norm,
                    use_layer_norm=use_layer_norm,
                    activation_fn=activation_fn,
                    **kwargs,
                )
            )
        else:
            # No hidden layers: the single output layer maps inputs directly to labels.
            n_hidden = n_input

        layers.append(nn.Linear(n_hidden, n_labels))

        if not logits:
            layers.append(nn.Softmax(dim=-1))

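A short sketch of what the two branches build (illustrative only, not part of the commit; assumes `Classifier` is importable from `scvi.module` as in this file):

```python
import torch
from scvi.module import Classifier

x = torch.randn(8, 32)

# Default head: FCLayers block + output Linear + Softmax (three submodules).
mlp_head = Classifier(n_input=32, n_hidden=16, n_labels=4, n_layers=1)
# n_layers=0 (or n_hidden=0) skips FCLayers entirely: Linear + Softmax only.
linear_head = Classifier(n_input=32, n_hidden=0, n_labels=4, n_layers=0)

assert len(mlp_head.classifier) == 3
assert len(linear_head.classifier) == 2
assert linear_head(x).shape == (8, 4)
```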
11 changes: 9 additions & 2 deletions scvi/module/_scanvae.py
@@ -62,10 +62,16 @@ class SCANVAE(VAE):
        Label group designations
    use_labels_groups
        Whether to use the label groups
    linear_classifier
        If `True`, uses a single linear layer for classification instead of a
        multi-layer perceptron.
    classifier_parameters
        Keyword arguments passed into :class:`~scvi.module.Classifier`.
    use_batch_norm
        Whether to use batch norm in layers
    use_layer_norm
        Whether to use layer norm in layers
    **vae_kwargs
        Keyword args for :class:`~scvi.module.VAE`
    """
@@ -89,6 +95,7 @@ def __init__(
        y_prior=None,
        labels_groups: Sequence[int] = None,
        use_labels_groups: bool = False,
        linear_classifier: bool = False,
        classifier_parameters: Optional[dict] = None,
        use_batch_norm: Tunable[Literal["encoder", "decoder", "none", "both"]] = "both",
        use_layer_norm: Tunable[Literal["encoder", "decoder", "none", "both"]] = "none",
@@ -120,8 +127,8 @@ def __init__(
        self.n_labels = n_labels
        # Classifier takes n_latent as input
        cls_parameters = {
            "n_layers": n_layers,
            "n_hidden": n_hidden,
            "n_layers": 0 if linear_classifier else n_layers,
            "n_hidden": 0 if linear_classifier else n_hidden,
            "dropout_rate": dropout_rate,
        }
        cls_parameters.update(classifier_parameters)
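
To make the mapping concrete, a hedged module-level sketch (illustrative, not part of the commit; user-supplied `classifier_parameters` are applied last via `cls_parameters.update(...)` and therefore still take precedence):

```python
from scvi.module import SCANVAE

# With the flag set, the head collapses to Linear(n_latent -> n_labels) + Softmax,
# regardless of the n_layers/n_hidden used by the rest of the model.
module = SCANVAE(n_input=100, n_labels=5, n_layers=2, n_hidden=128, linear_classifier=True)
print(module.classifier.classifier)
# Expected (n_latent defaults to 10):
# Sequential(
#   (0): Linear(in_features=10, out_features=5, bias=True)
#   (1): Softmax(dim=-1)
# )
```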
18 changes: 18 additions & 0 deletions tests/models/test_models.py
@@ -1112,6 +1112,24 @@ def test_scanvi(save_path):
    scanvi_model.train(1)


def test_linear_classifier_scanvi(n_latent: int = 10, n_labels: int = 5):
    adata = synthetic_iid(n_labels=n_labels)
    SCANVI.setup_anndata(
        adata,
        "labels",
        "label_0",
        batch_key="batch",
    )
    model = SCANVI(adata, linear_classifier=True, n_latent=n_latent)

    assert len(model.module.classifier.classifier) == 2  # linear layer + softmax
    assert isinstance(model.module.classifier.classifier[0], torch.nn.Linear)
    assert model.module.classifier.classifier[0].in_features == n_latent
    assert model.module.classifier.classifier[0].out_features == n_labels - 1

    model.train(max_epochs=1)


def test_linear_scvi(save_path):
    adata = synthetic_iid()
    adata = adata[:, :10].copy()
