Skip to content

Commit

Permalink
Added scanvi support, including CZI datamodule fix for it
Browse files Browse the repository at this point in the history
  • Loading branch information
ori-kron-wis committed Oct 15, 2024
1 parent f94f7fa commit 962f043
Show file tree
Hide file tree
Showing 3 changed files with 254 additions and 46 deletions.
126 changes: 119 additions & 7 deletions src/scvi/model/_scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch
from anndata import AnnData

import scvi
from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager
from scvi.data._constants import (
Expand Down Expand Up @@ -44,6 +45,7 @@
from typing import Literal

from anndata import AnnData
from lightning import LightningDataModule

from scvi._types import MinifiedDataType
from scvi.data.fields import (
Expand Down Expand Up @@ -127,12 +129,13 @@ def __init__(
dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb",
linear_classifier: bool = False,
datamodule: LightningDataModule | None = None,
**model_kwargs,
):
super().__init__(adata, registry)
scanvae_model_kwargs = dict(model_kwargs)

self._set_indices_and_labels()
self._set_indices_and_labels(datamodule)

# ignores unlabeled catgegory
n_labels = self.summary_stats.n_labels - 1
Expand Down Expand Up @@ -268,17 +271,21 @@ def from_scvi_model(

return scanvi_model

def _set_indices_and_labels(self):
def _set_indices_and_labels(self, datamodule=None):
"""Set indices for labeled and unlabeled cells."""
labels_state_registry = self.get_state_registry(REGISTRY_KEYS.LABELS_KEY)
self.original_label_key = labels_state_registry.original_key
self.unlabeled_category_ = labels_state_registry.unlabeled_category

labels = get_anndata_attribute(
self.adata,
self.adata_manager.data_registry.labels.attr_name,
self.original_label_key,
).ravel()
if datamodule is None:
labels = get_anndata_attribute(
self.adata,
self.adata_manager.data_registry.labels.attr_name,
self.original_label_key,
).ravel()
else:
# for CZI:
labels = list(datamodule.datapipe.map(lambda x: x["label"]))
self._label_mapping = labels_state_registry.categorical_mapping

# set unlabeled and labeled indices
Expand Down Expand Up @@ -500,6 +507,111 @@ def setup_anndata(
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)

@classmethod
@setup_anndata_dsp.dedent
def setup_datamodule(
cls,
datamodule: LightningDataModule | None = None,
source_registry=None,
layer: str | None = None,
batch_key: list[str] | None = None,
labels_key: str | None = None,
size_factor_key: str | None = None,
categorical_covariate_keys: list[str] | None = None,
continuous_covariate_keys: list[str] | None = None,
**kwargs,
):
"""%(summary)s.
Parameters
----------
%(param_datamodule)s
%(param_source_registry)s
%(param_layer)s
%(param_batch_key)s
%(param_size_factor_key)s
%(param_cat_cov_keys)s
%(param_cont_cov_keys)s
"""
if datamodule.__class__.__name__ == "CensusSCVIDataModule":
# CZI
batch_mapping = datamodule.datapipe.obs_encoders["batch"].classes_
labels_mapping = datamodule.datapipe.obs_encoders["label"].classes_
features_names = list(
datamodule.datapipe.var_query.coords[0]
if datamodule.datapipe.var_query is not None
else range(datamodule.n_vars)
)
n_batch = datamodule.n_batch
n_label = datamodule.n_label

else:
# Anndata -> CZI
# if we are here and datamodule is actually an AnnData object
# it means we init the custom dataloder model with anndata
batch_mapping = source_registry["field_registries"]["batch"]["state_registry"][
"categorical_mapping"
]
labels_mapping = source_registry["field_registries"]["label"]["state_registry"][
"categorical_mapping"
]
features_names = datamodule.var.soma_joinid.values
n_batch = source_registry["field_registries"]["batch"]["summary_stats"]["n_batch"]
n_label = 1 # need to change

datamodule.registry = {
"scvi_version": scvi.__version__,
"model_name": "SCVI",
"setup_args": {
"layer": layer,
"batch_key": batch_key,
"labels_key": labels_key,
"size_factor_key": size_factor_key,
"categorical_covariate_keys": categorical_covariate_keys,
"continuous_covariate_keys": continuous_covariate_keys,
},
"field_registries": {
"X": {
"data_registry": {"attr_name": "X", "attr_key": None},
"state_registry": {
"n_obs": datamodule.n_obs,
"n_vars": datamodule.n_vars,
"column_names": [str(i) for i in features_names],
},
"summary_stats": {"n_vars": datamodule.n_vars, "n_cells": datamodule.n_obs},
},
"batch": {
"data_registry": {"attr_name": "obs", "attr_key": "_scvi_batch"},
"state_registry": {
"categorical_mapping": batch_mapping,
"original_key": "batch",
},
"summary_stats": {"n_batch": n_batch},
},
"labels": {
"data_registry": {"attr_name": "obs", "attr_key": "_scvi_labels"},
"state_registry": {
"categorical_mapping": labels_mapping,
"original_key": "label",
"unlabeled_category": datamodule.unlabeled_category,
},
"summary_stats": {"n_labels": n_label},
},
"size_factor": {"data_registry": {}, "state_registry": {}, "summary_stats": {}},
"extra_categorical_covs": {
"data_registry": {},
"state_registry": {},
"summary_stats": {"n_extra_categorical_covs": 0},
},
"extra_continuous_covs": {
"data_registry": {},
"state_registry": {},
"summary_stats": {"n_extra_continuous_covs": 0},
},
},
"setup_method_name": "setup_datamodule",
}

@staticmethod
def _get_fields_for_adata_minification(
minified_data_type: MinifiedDataType,
Expand Down
7 changes: 5 additions & 2 deletions src/scvi/model/_scvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from typing import Literal

from anndata import AnnData
from lightning import LightningDataModule

from scvi._types import MinifiedDataType
from scvi.data.fields import (
Expand Down Expand Up @@ -127,6 +128,7 @@ def __init__(
dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
gene_likelihood: Literal["zinb", "nb", "poisson", "normal"] = "zinb",
latent_distribution: Literal["normal", "ln"] = "normal",
datamodule: LightningDataModule | None = None,
**kwargs,
):
super().__init__(adata, registry)
Expand Down Expand Up @@ -233,7 +235,7 @@ def setup_anndata(
@setup_anndata_dsp.dedent
def setup_datamodule(
cls,
datamodule,
datamodule: LightningDataModule | None = None,
source_registry=None,
layer: str | None = None,
batch_key: list[str] | None = None,
Expand All @@ -247,7 +249,8 @@ def setup_datamodule(
Parameters
----------
%(param_adata)s
%(param_datamodule)s
%(param_source_registry)s
%(param_layer)s
%(param_batch_key)s
%(param_labels_key)s
Expand Down
Loading

0 comments on commit 962f043

Please sign in to comment.