diff --git a/src/scvi_hub_models/__main__.py b/src/scvi_hub_models/__main__.py index f39cf3e..dc083eb 100644 --- a/src/scvi_hub_models/__main__.py +++ b/src/scvi_hub_models/__main__.py @@ -12,7 +12,15 @@ @click.option("--dry_run", type=bool, default=False, help="Dry run the workflow.") @click.option("--config_key", type=str, help="Use a different config file, e.g. for test purpose.") @click.option("--save_dir", type=str, help="Directory to save intermediate results (defaults temporary).") -def run_workflow(model_name: str, dry_run: bool, config_key: str = None, save_dir: str = None) -> None: +@click.option("--reload_data", type=bool, help="Reload the data or get from DVC.") +@click.option("--reload_model", type=bool, help="Reload the model or get from DVC.") +def run_workflow( + model_name: str, + dry_run: bool, + config_key: str = None, + save_dir: str = None, + reload_data: bool = False, + reload_model: bool = False) -> None: """Run the workflow for a specific model.""" from importlib import import_module if not config_key: @@ -22,7 +30,7 @@ def run_workflow(model_name: str, dry_run: bool, config_key: str = None, save_di Workflow = workflow_module._Workflow config = json_data_store[config_key] - workflow = Workflow(save_dir=save_dir, dry_run=dry_run, config=config) + workflow = Workflow(save_dir=save_dir, dry_run=dry_run, config=config, reload_data=reload_data, reload_model=reload_model) workflow.run() diff --git a/src/scvi_hub_models/config/haniffa_covid_pbmc.json b/src/scvi_hub_models/config/haniffa_covid_pbmc.json index fe5bb16..12c9e50 100644 --- a/src/scvi_hub_models/config/haniffa_covid_pbmc.json +++ b/src/scvi_hub_models/config/haniffa_covid_pbmc.json @@ -2,9 +2,11 @@ "model_dir": "haniffa_covid_pbmc", "model_class": "TOTALVI", "repo_name": "scvi-tools/haniffa_covid_pbmc_totalvi", + "reload_data": true, "extra_data_kwargs": { "reference_adata_cxg_id": "c7775e88-49bf-4ba2-a03b-93f00447c958", - "reference_adata_fname": "haniffa_covid_pbmc.h5ad" + "reference_adata_fname": "haniffa_covid_pbmc.h5ad", + "large_training_file_name": "haniffa_covid_pbmc.h5mu" }, "metadata": { "training_data_url": "https://datasets.cellxgene.cziscience.com/5ad66a4f-d619-4cb3-8015-a87c755647b3.h5ad", @@ -16,6 +18,7 @@ "description": "CITE-seq to measure RNA and surface proteins in thymocytes from wild-type and T cell lineage-restricted mice to generate a comprehensive timeline of cell state for each T cell lineage.", "references": "Steier, Z., Aylard, D.A., McIntyre, L.L. et al. Single-cell multiomic analysis of thymocyte development reveals drivers of CD4+ T cell and CD8+ T cell lineage commitment. Nat Immunol 24, 1579–1590 (2023). https://doi.org/10.1038/s41590-023-01584-0." }, + "criticism_settings": { "n_samples": 3, "cell_type_key": "cell_type" diff --git a/src/scvi_hub_models/config/heart_cell_atlas.json b/src/scvi_hub_models/config/heart_cell_atlas.json index 71ee449..70a77f5 100644 --- a/src/scvi_hub_models/config/heart_cell_atlas.json +++ b/src/scvi_hub_models/config/heart_cell_atlas.json @@ -2,6 +2,10 @@ "model_dir": "heart_cell_atlas_scvi", "model_class": "SCVI", "repo_name": "scvi-tools/heart-cell-atlas-scvi", + "extra_data_kwargs": { + "reference_adata_fname": "heart_cell_atlas.h5ad", + "large_training_file_name": "heart_cell_atlas.h5ad" + }, "metadata": { "training_data_url": "https://www.heartcellatlas.org/#DataSources", "tissues": ["heart"], diff --git a/src/scvi_hub_models/config/human_lung_cell_atlas.json b/src/scvi_hub_models/config/human_lung_cell_atlas.json index 3054b4b..a26c61e 100644 --- a/src/scvi_hub_models/config/human_lung_cell_atlas.json +++ b/src/scvi_hub_models/config/human_lung_cell_atlas.json @@ -1,5 +1,5 @@ { - "model_dir": "hlca_scanvi_reference", + "model_dir": "hlca_reference_scanvi", "model_class": "SCANVI", "repo_name": "scvi-tools/human-lung-cell-atlas-scanvi", "extra_data_kwargs": { @@ -7,7 +7,8 @@ "legacy_model_hash": "a7cd60f4342292b3cba54545bcd8a34decdc8e6b82163f009273d543e7e3910e", "legacy_model_dir": "hlca_scanvi_reference_legacy", "reference_adata_cxg_id": "066943a2-fdac-4b29-b348-40cede398e4e", - "reference_adata_fname": "hlca_core.h5ad" + "reference_adata_fname": "hlca_core.h5ad", + "large_training_file_name": "hlca_core.h5ad" }, "metadata": { "training_data_url": "https://cellxgene.cziscience.com/collections/6f6d381a-7701-4781-935c-db10d30de293", diff --git a/src/scvi_hub_models/config/mouse_thymus_cite.json b/src/scvi_hub_models/config/mouse_thymus_cite.json index 06b590a..5c2ca42 100644 --- a/src/scvi_hub_models/config/mouse_thymus_cite.json +++ b/src/scvi_hub_models/config/mouse_thymus_cite.json @@ -1,11 +1,10 @@ { - "model_dir": "mouse_thymus_cite", + "model_dir": "mouse_thymus_cite_totalvi", "model_class": "TOTALVI", - "repo_name": "scvi-tools/mouse_thymus_totalvi", - "reload_data": true, + "repo_name": "scvi-tools/mouse_thymus_cite_totalvi", "extra_data_kwargs": { "reference_adata_cxg_id": "c14c54f8-85d8-45db-9de7-6ab572cc748a", - "reference_adata_fname": "thymus_cite.h5ad", + "reference_adata_fname": "mouse_thymus_cite.h5ad", "large_training_file_name": "mouse_thymus_cite.h5mu" }, "metadata": { diff --git a/src/scvi_hub_models/config/neurips_cite.json b/src/scvi_hub_models/config/neurips_cite.json index 8d7eec6..e8cf6e0 100644 --- a/src/scvi_hub_models/config/neurips_cite.json +++ b/src/scvi_hub_models/config/neurips_cite.json @@ -1,9 +1,12 @@ { - "model_dir": "bone_marrow_cite", + "model_dir": "bone_marrow_cite_totalvi", "model_class": "TOTALVI", "repo_name": "scvi-tools/bone_marrow_cite_totalvi", "extra_data_kwargs": { - "reference_adata_fname": "bmmc_cite.h5ad" + "url": "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE194122&format=file&file=GSE194122%5Fopenproblems%5Fneurips2021%5Fcite%5FBMMC%5Fprocessed%2Eh5ad%2Egz", + "hash": "b9b50fade9349719cba23c97c6515d3501a32ee3735fe95fe51221d2e8a5f361", + "reference_adata_fname": "bmmc_cite.h5ad.gz", + "large_training_file_name": "neurips_bone_marrow_cite.h5mu" }, "metadata": { "training_data_url": "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE194122&format=file&file=GSE194122%5Fopenproblems%5Fneurips2021%5Fcite%5FBMMC%5Fprocessed%2Eh5ad%2Egz", diff --git a/src/scvi_hub_models/config/test_scvi.json b/src/scvi_hub_models/config/test_scvi.json index 3ba0f83..104b9e1 100644 --- a/src/scvi_hub_models/config/test_scvi.json +++ b/src/scvi_hub_models/config/test_scvi.json @@ -2,9 +2,11 @@ "model_dir": "test_scvi", "model_class": "SCVI", "repo_name": "scvi-tools/test-scvi", + "extra_data_kwargs": { + "large_training_file_name": "test_data.h5ad" + }, "collection_name": "test", "minify_model": false, - "extra_data_kwargs": {}, "metadata": { "tissues": ["synthetic"], "data_modalities": ["rna"], diff --git a/src/scvi_hub_models/models/_base_workflow.py b/src/scvi_hub_models/models/_base_workflow.py index 08d19c8..0787577 100644 --- a/src/scvi_hub_models/models/_base_workflow.py +++ b/src/scvi_hub_models/models/_base_workflow.py @@ -14,28 +14,9 @@ from scvi.hub import HubMetadata, HubModel, HubModelCardHelper from scvi.model.base import BaseModelClass -import subprocess, os -from pydrive.auth import GoogleAuth -from pydrive.drive import GoogleDrive - - -def upload_gdrive(path): - remote_url = subprocess.check_output(["dvc", "remote", "list"], text=True).split()[-1] - - # Extract the Google Drive folder ID from the remote URL - folder_id = remote_url.split("gdrive://")[-1] - - # Check if DVC detected an update - if "modified" in subprocess.check_output(["dvc", "diff", "--json"], text=True): - GoogleAuth().LocalWebserverAuth() - drive = GoogleDrive(GoogleAuth()) - file = drive.CreateFile({"title": os.path.basename(path), "parents": [{"id": folder_id}]}) - file.SetContentFile(path) - file.Upload() - print(f"Uploaded {path} to Google Drive folder {folder_id}.") - # Specify your repository and target file -repo_path = "." + +repo_path = os.path.abspath(Path(__file__).parent.parent.parent.parent) dvc_repo = Repo(repo_path) git_repo = git.Repo(repo_path) @@ -67,17 +48,25 @@ class BaseModelWorkflow: config A :class:`~frozendict.frozendict` containing the configuration for the workflow. Can only be set once. + reload_data + If ``True``, the data will be reloaded. Otherwise, it will be pulled from DVC. Defaults to ``False``. + reload_model + If ``True``, the model will be reloaded. Otherwise, it will be pulled from DVC. Defaults to ``False``. """ def __init__( self, save_dir: str | None = None, dry_run: bool = False, - config: frozendict | None = None + config: frozendict | None = None, + reload_data: bool = True, + reload_model: bool = True, ): self.save_dir = save_dir self.dry_run = dry_run self.config = config + self.reload_data = reload_data + self.reload_model = reload_model @property def save_dir(self): @@ -114,22 +103,41 @@ def config(self, value: frozendict): value = frozendict(value) self._config = value + @property + def reload_data(self): + return self._reload_data + + @reload_data.setter + def reload_data(self, value: bool): + if hasattr(self, "_reload_data"): + raise AttributeError("`reload_data` can only be set once.") + self._reload_data = value + + @property + def reload_model(self): + return self._reload_model + + @reload_model.setter + def reload_model(self, value: bool): + if hasattr(self, "_reload_model"): + raise AttributeError("`reload_model` can only be set once.") + self._reload_model = value + def get_adata(self) -> anndata.AnnData | None: """Download and load the dataset.""" logger.info("Loading dataset.") if self.dry_run: return None - if self.config['reload_data']: - path_file = os.path.join('data/', self.config['extra_data_kwargs']['large_training_file_name']) - adata = self.download_adata() + if self.reload_data: + path_file = os.path.join(f'{repo_path}/data/', self.config['extra_data_kwargs']['large_training_file_name']) + print(path_file) + adata = self.download_adata(path_file) dvc_repo.add(path_file) git_repo.index.commit(f"Track {path_file} with DVC") - print(f"Pushing {path_file} to DVC remote...") dvc_repo.push() git_repo.remote().push() - upload_gdrive(path_file) else: - path_file = os.path.join('data/', self.config['extra_data_kwargs']['large_training_file_name']) + path_file = os.path.join(f'{repo_path}/data/', self.config['extra_data_kwargs']['large_training_file_name']) dvc_repo.pull([path_file]) if path_file.endswith(".h5mu"): adata = mudata.read_h5mu(path_file) @@ -137,34 +145,40 @@ def get_adata(self) -> anndata.AnnData | None: adata = anndata.read_h5ad(path_file) return adata - def _get_adata(self, url: str, hash: str, file_path: str) -> str: + def get_model(self, adata) -> BaseModelClass | None: + """Download and load the model.""" + logger.info("Loading model.") + if self.dry_run: + return None + if self.reload_model: + path_file = os.path.join(f'{repo_path}/data/', self.config['model_dir']) + model = self.load_model(adata) + model.save(path_file, overwrite=True, save_anndata=False) + dvc_repo.add(path_file) + git_repo.index.commit(f"Track {path_file} with DVC") + dvc_repo.push() + git_repo.remote().push() + else: + path_file = os.path.join(f'{repo_path}/data/', self.config['model_dir']) + dvc_repo.pull([path_file]) + model = self.default_load_model(adata, self.config['model_class'], path_file) + return model + + def _get_adata(self, url: str, hash: str, file_path: str, processor: str | None = None) -> str: logger.info("Downloading and reading data.") if self.dry_run: return None - retrieve( + file_out = retrieve( url=url, known_hash=hash, fname=file_path, path=self.save_dir, - processor=None, - ) - return anndata.read_h5ad(os.path.join(self.save_dir, file_path)) - - def _download_model(self, url: str, hash: str, file_path: str) -> str: - logger.info("Downloading model.") - if self.dry_run: - return None - - return retrieve( - url=url, # - known_hash=hash, #config["adata_hashes"][tissue], - fname=file_path, # f"{tissue}_adata.h5ad" - path=self.save_dir, - processor=None, + processor=processor, ) + return anndata.read_h5ad(file_out) - def _load_model(self, model_path: str, adata: anndata.AnnData, model_name: str): + def default_load_model(self, adata: anndata.AnnData, model_name: str, model_path: str | None = None) -> BaseModelClass: """Load the model.""" logger.info("Loading model.") if self.dry_run: @@ -178,48 +192,20 @@ def _load_model(self, model_path: str, adata: anndata.AnnData, model_name: str): elif model_name == "CondSCVI": from scvi.model import CondSCVI model_cls = CondSCVI + elif model_name == "TOTALVI": + from scvi.model import TOTALVI + model_cls = TOTALVI elif model_name == "Stereoscope": from scvi.external import RNAStereoscope model_cls = RNAStereoscope else: raise ValueError(f"Model {model_name} not recognized.") - model = model_cls.load(os.path.join(self.save_dir, model_path), adata=adata) + if model_path is None: + model_path = os.path.join(self.save_dir, self.config["model_dir"]) + model = model_cls.load(model_path, adata=adata) return model - def _create_hub_model( - self, - model_path: str, - training_data_url: str | None = None - ) -> HubModel | None: - logger.info("Creating the HubModel.") - if self.dry_run: - return None - - if training_data_url is None: - training_data_url = self.config.get("training_data_url", None) - - metadata = self.config["metadata"] - hub_metadata = HubMetadata.from_dir( - model_path, - anndata_version=anndata_version - ) - model_card = HubModelCardHelper.from_dir( - model_path, - anndata_version=anndata_version, - license_info=metadata.get("license_info", "mit"), - data_modalities=metadata.get("data_modalities", None), - tissues=metadata.get("tissues", None), - data_is_annotated=metadata.get("data_is_annotated", False), - data_is_minified=metadata.get("data_is_minified", False), - training_data_url=training_data_url, - training_code_url=metadata.get("training_code_url", None), - description=metadata.get("description", None), - references=metadata.get("references", None), - ) - - return HubModel(model_path, hub_metadata, model_card) - def _minify_and_save_model( self, model: BaseModelClass, @@ -252,11 +238,47 @@ def _minify_and_save_model( qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) adata.obsm[qzm_key] = qzm adata.obsm[qzv_key] = qzv - model.minify_adata(use_latent_qzm_key=qzm_key, use_latent_qzv_key=qzv_key) + if isinstance(adata, mudata.MuData): + model.minify_mudata(use_latent_qzm_key=qzm_key, use_latent_qzv_key=qzv_key) + else: + model.minify_adata(use_latent_qzm_key=qzm_key, use_latent_qzv_key=qzv_key) model.save(mini_model_path, overwrite=True, save_anndata=True) return mini_model_path + def _create_hub_model( + self, + model_path: str, + training_data_url: str | None = None + ) -> HubModel | None: + logger.info("Creating the HubModel.") + if self.dry_run: + return None + + if training_data_url is None: + training_data_url = self.config.get("training_data_url", None) + + metadata = self.config["metadata"] + hub_metadata = HubMetadata.from_dir( + model_path, + anndata_version=anndata_version + ) + model_card = HubModelCardHelper.from_dir( + model_path, + anndata_version=anndata_version, + license_info=metadata.get("license_info", "mit"), + data_modalities=metadata.get("data_modalities", None), + tissues=metadata.get("tissues", None), + data_is_annotated=metadata.get("data_is_annotated", False), + data_is_minified=metadata.get("data_is_minified", False), + training_data_url=training_data_url, + training_code_url=metadata.get("training_code_url", None), + description=metadata.get("description", None), + references=metadata.get("references", None), + ) + + return HubModel(model_path, hub_metadata, model_card) + def _upload_hub_model(self, hub_model: HubModel, repo_name: str | None = None, **kwargs) -> HubModel: """Upload the HubModel to Hugging Face.""" collection_name = self.config.get("collection_name", None) diff --git a/src/scvi_hub_models/models/_haniffa_covid_pbmc.py b/src/scvi_hub_models/models/_haniffa_covid_pbmc.py index dc47600..6262d18 100644 --- a/src/scvi_hub_models/models/_haniffa_covid_pbmc.py +++ b/src/scvi_hub_models/models/_haniffa_covid_pbmc.py @@ -18,15 +18,12 @@ def _load_adata(self) -> AnnData: adata_path = os.path.join(self.save_dir, self.config['extra_data_kwargs']["reference_adata_fname"]) if not os.path.exists(adata_path): - # TODO for next LTX remove census_version='latest'. - download_source_h5ad(self.config['extra_data_kwargs']["reference_adata_cxg_id"], to_path=adata_path, census_version='latest') + download_source_h5ad(self.config['extra_data_kwargs']["reference_adata_cxg_id"], to_path=adata_path) return sc.read_h5ad(adata_path) def _preprocess_adata(self, adata: AnnData) -> AnnData: import scanpy as sc - print(adata, adata.X.data) - sc.pp.filter_genes(adata, min_counts=3) adata.layers["counts"] = adata.X.copy() sc.pp.highly_variable_genes( @@ -38,20 +35,28 @@ def _preprocess_adata(self, adata: AnnData) -> AnnData: batch_key="sample_id", span=1.0, ) - protein_adata = AnnData(adata.obsm["protein_expression"]) + protein_adata = AnnData( + adata.uns['antibody_raw.X'].toarray(), + obs=adata.obs, + var=adata.uns['antibody_features']) + protein_adata.obs_names = adata.obs_names - del adata.obsm["protein_expression"] - adata = MuData({"rna": adata, "protein": protein_adata}) + del adata.uns['antibody_raw.X'] + del adata.uns['antibody_features'] + del adata.uns['antibody_X'] + del adata.uns['neighbors'] + mdata = MuData({"rna": adata, "protein": protein_adata}) - return adata + return mdata - def load_adata(self) -> AnnData | None: + def download_adata(self, path) -> AnnData | None: """Download and load the dataset.""" logger.info(f"Saving dataset to {self.save_dir} and preprocessing.") if self.dry_run: return None adata = self._load_adata() mdata = self._preprocess_adata(adata) + mdata.write_h5mu(path) return mdata def _initialize_model(self, mdata: MuData) -> TOTALVI: @@ -59,7 +64,7 @@ def _initialize_model(self, mdata: MuData) -> TOTALVI: mdata, rna_layer="counts", protein_layer=None, - batch_key="sample_id", + batch_key="donor_id", modalities={ "rna_layer": "rna", "protein_layer": "protein", @@ -70,11 +75,11 @@ def _initialize_model(self, mdata: MuData) -> TOTALVI: def _train_model(self, model: TOTALVI) -> TOTALVI: """Train the scVI model.""" - model.train(max_epochs=200) + model.train(max_epochs=50) return model - def get_model(self, adata) -> TOTALVI | None: + def load_model(self, adata) -> TOTALVI | None: """Initialize and train the scVI model.""" logger.info("Training the scVI model.") if self.dry_run: @@ -85,7 +90,7 @@ def get_model(self, adata) -> TOTALVI | None: def run(self): super().run() - mdata = self.load_adata() + mdata = self.get_adata() model = self.get_model(mdata) model_path = self._minify_and_save_model(model, mdata) hub_model = self._create_hub_model(model_path) diff --git a/src/scvi_hub_models/models/_heart_cell_atlas.py b/src/scvi_hub_models/models/_heart_cell_atlas.py index 65c3822..abea1b5 100644 --- a/src/scvi_hub_models/models/_heart_cell_atlas.py +++ b/src/scvi_hub_models/models/_heart_cell_atlas.py @@ -31,13 +31,14 @@ def _preprocess_adata(self, adata: AnnData) -> AnnData: return adata - def load_adata(self) -> AnnData | None: + def download_adata(self, path) -> AnnData | None: """Download and load the dataset.""" logger.info(f"Saving heart cell atlas dataset to {self.save_dir}.") if self.dry_run: return None adata = self._load_adata() adata = self._preprocess_adata(adata) + adata.write_h5ad(path) return adata def _initialize_model(self, adata: AnnData) -> SCVI: @@ -51,11 +52,11 @@ def _initialize_model(self, adata: AnnData) -> SCVI: def _train_model(self, model: SCVI) -> SCVI: """Train the scVI model.""" - model.train(max_epochs=5) + model.train(max_epochs=200) return model - def get_model(self, adata) -> SCVI | None: + def load_model(self, adata) -> SCVI | None: """Initialize and train the scVI model.""" logger.info("Training the scVI model.") if self.dry_run: @@ -66,7 +67,7 @@ def get_model(self, adata) -> SCVI | None: def run(self): super().run() - adata = self.load_adata() + adata = self.get_adata() model = self.get_model(adata) model_path = self._minify_and_save_model(model, adata) hub_model = self._create_hub_model(model_path) diff --git a/src/scvi_hub_models/models/_human_lung_cell_atlas.py b/src/scvi_hub_models/models/_human_lung_cell_atlas.py index f68baef..6dd8101 100644 --- a/src/scvi_hub_models/models/_human_lung_cell_atlas.py +++ b/src/scvi_hub_models/models/_human_lung_cell_atlas.py @@ -10,6 +10,9 @@ class _Workflow(BaseModelWorkflow): + def load_model(self, adata: anndata.AnnData) -> BaseModelWorkflow: + return self.default_load_model(adata, self.config['model_class']) + def _download_model(self): from pathlib import Path @@ -23,7 +26,6 @@ def _download_model(self): path=self.save_dir, ) untarred = sorted(untarred) - print(untarred) return str(Path(untarred[0]).parent) def _get_model(self) -> str: @@ -64,7 +66,7 @@ def _preprocess_reference_adata(self, adata: anndata.AnnData, model_path: str) - # .X does not contain raw counts initially adata.X = adata.raw.X - _, genes, _, _ = _load_saved_files(model_path, load_adata=False) + _, genes, _, _ = _load_saved_files(os.path.join(self.save_dir, self.config["model_dir"]), load_adata=False) adata = adata[:, adata.var.index.isin(genes)].copy() # get rid of some var columns that we dont need @@ -108,13 +110,14 @@ def _download_embedding_adata(self) -> str: adata = anndata.io.read_h5ad(adata) return adata[adata.obs["core_or_extension"] == "core"].copy() - def _get_adata(self, model_path: str) -> anndata.AnnData: + def download_adata(self, path) -> anndata.AnnData: logging.info("Loading data.") if self.dry_run: return None ref_adata = self._download_reference_adata() - ref_adata = self._preprocess_reference_adata(ref_adata, model_path) + ref_adata = self._preprocess_reference_adata(ref_adata, self.model_path) ref_adata = self._postprocess_reference_adata(ref_adata) + ref_adata.write_h5ad(path) return ref_adata @property @@ -124,9 +127,9 @@ def id(self) -> str: def run(self): super().run() - model_path = self._get_model() - adata = self._get_adata(model_path) - model = self._load_model(model_path, adata, "SCANVI") + self._get_model() + adata = self.get_adata() + model = self.get_model(adata) model_path = self._minify_and_save_model(model, adata) hub_model = self._create_hub_model(model_path) hub_model = self._upload_hub_model(hub_model) diff --git a/src/scvi_hub_models/models/_mouse_thymus_cite.py b/src/scvi_hub_models/models/_mouse_thymus_cite.py index 56c1ec8..255b1ee 100644 --- a/src/scvi_hub_models/models/_mouse_thymus_cite.py +++ b/src/scvi_hub_models/models/_mouse_thymus_cite.py @@ -23,28 +23,35 @@ def _load_adata(self) -> AnnData: return sc.read_h5ad(adata_path) def _preprocess_adata(self, adata: AnnData) -> AnnData: - import scanpy as sc - - print(adata, adata.X.data) - - sc.pp.filter_genes(adata, min_counts=3) matching_indices = [adata.raw.var_names.get_loc(gene) for gene in adata.var_names] adata.layers["counts"] = adata.raw.X[:, matching_indices].copy() - protein_adata = AnnData(adata.obsm["protein_expression"]) + sc.pp.highly_variable_genes( + adata, + n_top_genes=4000, + subset=True, + layer="counts", + flavor="seurat_v3", + span=1.0, + ) + protein_adata = AnnData(adata.obsm["protein_expression"], obs=adata.obs) protein_adata.obs_names = adata.obs_names del adata.obsm["protein_expression"] - adata = MuData({"rna": adata, "protein": protein_adata}) + del adata.obsm['denoised_genes'] + del adata.obsm['denoised_proteins'] + del adata.uns['AB_adata'] + mdata = MuData({"rna": adata, "protein": protein_adata}) + print(mdata) - return adata + return mdata - def download_adata(self) -> AnnData | None: + def download_adata(self, path) -> AnnData | None: """Download and load the dataset.""" - logger.info(f"Saving dataset to {self.save_dir} and preprocessing.") + logger.info(f"Saving dataset to {path} and preprocessing.") if self.dry_run: return None adata = self._load_adata() mdata = self._preprocess_adata(adata) - mdata.write_h5mu(f'data/{self.config['extra_data_kwargs']['large_training_file_name']}') + mdata.write_h5mu(path) return mdata def _initialize_model(self, mdata: MuData) -> TOTALVI: @@ -67,7 +74,7 @@ def _train_model(self, model: TOTALVI) -> TOTALVI: return model - def get_model(self, adata) -> TOTALVI | None: + def load_model(self, adata) -> TOTALVI | None: """Initialize and train the scVI model.""" logger.info("Training the scVI model.") if self.dry_run: diff --git a/src/scvi_hub_models/models/_neurips_cite.py b/src/scvi_hub_models/models/_neurips_cite.py index 74f6cfd..af4c470 100644 --- a/src/scvi_hub_models/models/_neurips_cite.py +++ b/src/scvi_hub_models/models/_neurips_cite.py @@ -1,9 +1,9 @@ import logging -import os import scanpy as sc from anndata import AnnData from mudata import MuData +from pooch import Decompress from scvi.model import TOTALVI from scvi_hub_models.models import BaseModelWorkflow @@ -12,52 +12,45 @@ class _Workflow(BaseModelWorkflow): - - def _load_adata(self) -> AnnData: - from cellxgene_census import download_source_h5ad - - adata_path = os.path.join(self.save_dir, self.config['extra_data_kwargs']["reference_adata_fname"]) - if not os.path.exists(adata_path): - # TODO for next LTS remove census_version='latest'. - download_source_h5ad(self.config['extra_data_kwargs']["reference_adata_cxg_id"], to_path=adata_path, census_version='latest') - return sc.read_h5ad(adata_path) - def _preprocess_adata(self, adata: AnnData) -> AnnData: - import scanpy as sc - - sc.pp.filter_genes(adata, min_counts=3) - adata.layers["counts"] = adata.X.copy() + rna = adata[:, adata.var['feature_types']=='GEX'].copy() + protein = adata[:, adata.var['feature_types']=='ADT'].copy() + protein.layers["counts"] = protein.layers["counts"].toarray() + sc.pp.filter_genes(rna, min_counts=3) sc.pp.highly_variable_genes( - adata, + rna, n_top_genes=4000, subset=True, layer="counts", flavor="seurat_v3", - batch_key="sample_id", + batch_key="Site", span=1.0, ) - protein_adata = AnnData(adata.obsm["protein_expression"]) - protein_adata.obs_names = adata.obs_names - del adata.obsm["protein_expression"] - adata = MuData({"rna": adata, "protein": protein_adata}) + adata = MuData({"rna": rna, "protein": protein}) return adata - def load_adata(self) -> AnnData | None: + def download_adata(self, path) -> AnnData | None: """Download and load the dataset.""" - logger.info(f"Saving dataset to {self.save_dir} and preprocessing.") + logger.info(f"Saving dataset to {path} and preprocessing.") if self.dry_run: return None - adata = self._load_adata() + adata = self._get_adata( + url=self.config["extra_data_kwargs"]["url"], + hash=self.config["extra_data_kwargs"]["hash"], + file_path=self.config["extra_data_kwargs"]["reference_adata_fname"], + processor=Decompress(), + ) mdata = self._preprocess_adata(adata) + mdata.write_h5mu(path) return mdata def _initialize_model(self, mdata: MuData) -> TOTALVI: TOTALVI.setup_mudata( mdata, rna_layer="counts", - protein_layer=None, - batch_key="sample_id", + protein_layer="counts", + batch_key="batch", modalities={ "rna_layer": "rna", "protein_layer": "protein", @@ -72,7 +65,7 @@ def _train_model(self, model: TOTALVI) -> TOTALVI: return model - def get_model(self, adata) -> TOTALVI | None: + def load_model(self, adata) -> TOTALVI | None: """Initialize and train the scVI model.""" logger.info("Training the scVI model.") if self.dry_run: @@ -83,7 +76,7 @@ def get_model(self, adata) -> TOTALVI | None: def run(self): super().run() - mdata = self.load_adata() + mdata = self.get_adata() model = self.get_model(mdata) model_path = self._minify_and_save_model(model, mdata) hub_model = self._create_hub_model(model_path) diff --git a/src/scvi_hub_models/models/_test_scvi.py b/src/scvi_hub_models/models/_test_scvi.py index b72913d..7c326b1 100644 --- a/src/scvi_hub_models/models/_test_scvi.py +++ b/src/scvi_hub_models/models/_test_scvi.py @@ -12,29 +12,23 @@ class _Workflow(BaseModelWorkflow): - def load_dataset(self) -> AnnData | None: + def download_adata(self, path) -> AnnData | None: from scvi.data import synthetic_iid logger.info("Loading synthetic dataset.") if self.dry_run: return None + adata = synthetic_iid() + adata.write_h5ad(path) + return adata - return synthetic_iid() - - def initialize_model(self, adata: AnnData | None) -> SCVI | None: - logger.info("Initializing the scVI model.") + def load_model(self, adata: AnnData) -> SCVI: + logger.info("Training the scVI model.") if self.dry_run: return None - SCVI.setup_anndata(adata) - return SCVI(adata) - - def train_model(self, model: SCVI | None) -> SCVI | None: - logger.info("Training the scVI model.") - if self.dry_run: - return model - - model.train(max_epochs=1) + model = SCVI(adata) + model.train(max_epochs=10) return model @property @@ -44,9 +38,8 @@ def id(self) -> str: def run(self): super().run() - adata = self.load_dataset() - model = self.initialize_model(adata) - model = self.train_model(model) + adata = self.get_adata() + model = self.get_model(adata) model_path = self._minify_and_save_model(model, adata) hub_model = self._create_hub_model(model_path) hub_model = self._upload_hub_model(hub_model) diff --git a/src/scvi_hub_models/test.ipynb b/src/scvi_hub_models/test.ipynb new file mode 100644 index 0000000..c730192 --- /dev/null +++ b/src/scvi_hub_models/test.ipynb @@ -0,0 +1,603 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.insert(0, '/home/cane/Documents/scvi-tools/scvi-hub-models/src')" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "import os\n", + "\n", + "import scanpy as sc\n", + "from anndata import AnnData\n", + "from mudata import MuData\n", + "from scvi.model import TOTALVI\n", + "\n", + "from scvi_hub_models.models import BaseModelWorkflow\n", + "\n", + "logger = logging.getLogger(__name__)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/mudata/_core/mudata.py:1531: FutureWarning: From 0.4 .update() will not pull obs/var columns from individual modalities by default anymore. Set mudata.set_options(pull_on_update=False) to adopt the new behaviour, which will become the default. Use new pull_obs/pull_var and push_obs/push_var methods for more flexibility.\n", + " self._update_attr(\"var\", axis=0, join_common=join_common)\n", + "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/mudata/_core/mudata.py:1429: FutureWarning: From 0.4 .update() will not pull obs/var columns from individual modalities by default anymore. Set mudata.set_options(pull_on_update=False) to adopt the new behaviour, which will become the default. Use new pull_obs/pull_var and push_obs/push_var methods for more flexibility.\n", + " self._update_attr(\"obs\", axis=1, join_common=join_common)\n" + ] + } + ], + "source": [ + "import mudata\n", + "mu = mudata.read_h5mu('/home/cane/Documents/scvi-tools/scvi-hub-models/test/mini_totalvi/mdata.h5mu')" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
MuData object with n_obs × n_vars = 72042 × 4111\n", + " obs:\t'_scvi_labels', 'observed_lib_size'\n", + " var:\t'cv_gene'\n", + " uns:\t'_scvi_adata_minify_type', '_scvi_manager_uuid', '_scvi_uuid'\n", + " obsm:\t'totalvi_latent_qzm', 'totalvi_latent_qzv'\n", + " 2 modalities\n", + " rna:\t72042 x 4000\n", + " obs:\t'percent_mito', 'n_counts', 'n_genes', 'n_protein_counts', 'n_proteins', 'leiden_totalVI_res1.4', 'leiden_totalVI_res1.0', 'leiden_totalVI_res0.6', 'annotations_clean', 'mean_pseudotime', 'Pseudotime_bin', 'curve1', 'curve2', 'difference', 'weight_curve1', 'weight_curve2', 'UMIs_RNA', 'UMIs_protein', 'n_genes_pt', 'n_proteins_pt', 'percent_mito_pt', 'Experiment', 'slingshot_clusters', 'organism_ontology_term_id', 'disease_ontology_term_id', 'sex_ontology_term_id', 'tissue_type', 'tissue_ontology_term_id', 'suspension_type', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'development_stage_ontology_term_id', 'batch_indices', 'sample_id', 'Location', 'donor_id', 'sample_weeks', 'genotype', 'Lineage_by_genotypeSlingshot', 'Lineage_by_genotype', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', '_scvi_batch', 'cv_cell'\n", + " var:\t'gene_id', 'gene_name', 'expression_type', 'n_cells', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'feature_type', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'cv_gene'\n", + " uns:\t'AB_adata', 'annotations_clean_colors', 'batch_condition', 'batch_indices_colors', 'citation', 'hvg', 'leiden', 'neighbors', 'protein_names', 'schema_reference', 'schema_version', 'title', 'totalVI_genes', 'totalVI_proteins', 'umap'\n", + " obsm:\t'X_totalVI', 'X_umap', 'denoised_genes', 'denoised_proteins'\n", + " varm:\t'lfc_model', 'lfc_raw'\n", + " layers:\t'counts'\n", + " protein:\t72042 x 111\n", + " obs:\t'percent_mito', 'n_counts', 'n_genes', 'n_protein_counts', 'n_proteins', 'leiden_totalVI_res1.4', 'leiden_totalVI_res1.0', 'leiden_totalVI_res0.6', 'annotations_clean', 'mean_pseudotime', 'Pseudotime_bin', 'curve1', 'curve2', 'difference', 'weight_curve1', 'weight_curve2', 'UMIs_RNA', 'UMIs_protein', 'n_genes_pt', 'n_proteins_pt', 'percent_mito_pt', 'Experiment', 'slingshot_clusters', 'organism_ontology_term_id', 'disease_ontology_term_id', 'sex_ontology_term_id', 'tissue_type', 'tissue_ontology_term_id', 'suspension_type', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'development_stage_ontology_term_id', 'batch_indices', 'sample_id', 'Location', 'donor_id', 'sample_weeks', 'genotype', 'Lineage_by_genotypeSlingshot', 'Lineage_by_genotype', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', '_scvi_batch', 'cv_cell'\n", + " var:\t'cv_gene'\n", + " varm:\t'lfc_model', 'lfc_raw'" + ], + "text/plain": [ + "MuData object with n_obs × n_vars = 72042 × 4111\n", + " obs:\t'_scvi_labels', 'observed_lib_size'\n", + " var:\t'cv_gene'\n", + " uns:\t'_scvi_adata_minify_type', '_scvi_manager_uuid', '_scvi_uuid'\n", + " obsm:\t'totalvi_latent_qzm', 'totalvi_latent_qzv'\n", + " 2 modalities\n", + " rna:\t72042 x 4000\n", + " obs:\t'percent_mito', 'n_counts', 'n_genes', 'n_protein_counts', 'n_proteins', 'leiden_totalVI_res1.4', 'leiden_totalVI_res1.0', 'leiden_totalVI_res0.6', 'annotations_clean', 'mean_pseudotime', 'Pseudotime_bin', 'curve1', 'curve2', 'difference', 'weight_curve1', 'weight_curve2', 'UMIs_RNA', 'UMIs_protein', 'n_genes_pt', 'n_proteins_pt', 'percent_mito_pt', 'Experiment', 'slingshot_clusters', 'organism_ontology_term_id', 'disease_ontology_term_id', 'sex_ontology_term_id', 'tissue_type', 'tissue_ontology_term_id', 'suspension_type', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'development_stage_ontology_term_id', 'batch_indices', 'sample_id', 'Location', 'donor_id', 'sample_weeks', 'genotype', 'Lineage_by_genotypeSlingshot', 'Lineage_by_genotype', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', '_scvi_batch', 'cv_cell'\n", + " var:\t'gene_id', 'gene_name', 'expression_type', 'n_cells', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'feature_type', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'cv_gene'\n", + " uns:\t'AB_adata', 'annotations_clean_colors', 'batch_condition', 'batch_indices_colors', 'citation', 'hvg', 'leiden', 'neighbors', 'protein_names', 'schema_reference', 'schema_version', 'title', 'totalVI_genes', 'totalVI_proteins', 'umap'\n", + " obsm:\t'X_totalVI', 'X_umap', 'denoised_genes', 'denoised_proteins'\n", + " varm:\t'lfc_model', 'lfc_raw'\n", + " layers:\t'counts'\n", + " protein:\t72042 x 111\n", + " obs:\t'percent_mito', 'n_counts', 'n_genes', 'n_protein_counts', 'n_proteins', 'leiden_totalVI_res1.4', 'leiden_totalVI_res1.0', 'leiden_totalVI_res0.6', 'annotations_clean', 'mean_pseudotime', 'Pseudotime_bin', 'curve1', 'curve2', 'difference', 'weight_curve1', 'weight_curve2', 'UMIs_RNA', 'UMIs_protein', 'n_genes_pt', 'n_proteins_pt', 'percent_mito_pt', 'Experiment', 'slingshot_clusters', 'organism_ontology_term_id', 'disease_ontology_term_id', 'sex_ontology_term_id', 'tissue_type', 'tissue_ontology_term_id', 'suspension_type', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'development_stage_ontology_term_id', 'batch_indices', 'sample_id', 'Location', 'donor_id', 'sample_weeks', 'genotype', 'Lineage_by_genotypeSlingshot', 'Lineage_by_genotype', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', '_scvi_batch', 'cv_cell'\n", + " var:\t'cv_gene'\n", + " varm:\t'lfc_model', 'lfc_raw'" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mu" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "from scvi_hub_models.config import json_data_store" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from scvi_hub_models.models import _neurips_cite\n", + "Workflow = _neurips_cite._Workflow" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "self = Workflow(config=json_data_store['neurips_cite'], save_dir='.', reload_model=True, reload_data=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'bmmc_cite.h5ad.gz'" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "self.config['extra_data_kwargs']['reference_adata_fname']" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Decompressing '/home/cane/.cache/pooch/bmmc_cite.h5ad.gz' to '/home/cane/.cache/pooch/bmmc_cite.h5ad.gz.decomp' using method 'auto'.\n" + ] + } + ], + "source": [ + "from pooch import retrieve\n", + "from pooch import Decompress\n", + "\n", + "adata_path = retrieve(\n", + " url=self.config['extra_data_kwargs']['url'],\n", + " known_hash=\"b9b50fade9349719cba23c97c6515d3501a32ee3735fe95fe51221d2e8a5f361\",\n", + " fname='bmmc_cite.h5ad.gz',\n", + " processor=Decompress(),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/anndata/_core/anndata.py:1758: UserWarning: Variable names are not unique. To make them unique, call `.var_names_make_unique`.\n", + " utils.warn_names_duplicates(\"var\")\n" + ] + } + ], + "source": [ + "import anndata\n", + "\n", + "ad = anndata.read_h5ad('/home/cane/.cache/pooch/bmmc_cite.h5ad')" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AnnData object with n_obs × n_vars = 90261 × 14087\n", + " obs: 'GEX_n_genes_by_counts', 'GEX_pct_counts_mt', 'GEX_size_factors', 'GEX_phase', 'ADT_n_antibodies_by_counts', 'ADT_total_counts', 'ADT_iso_count', 'cell_type', 'batch', 'ADT_pseudotime_order', 'GEX_pseudotime_order', 'Samplename', 'Site', 'DonorNumber', 'Modality', 'VendorLot', 'DonorID', 'DonorAge', 'DonorBMI', 'DonorBloodType', 'DonorRace', 'Ethnicity', 'DonorGender', 'QCMeds', 'DonorSmoker', 'is_train'\n", + " var: 'feature_types', 'gene_id'\n", + " uns: 'dataset_id', 'genome', 'organism'\n", + " obsm: 'ADT_X_pca', 'ADT_X_umap', 'ADT_isotype_controls', 'GEX_X_pca', 'GEX_X_umap'\n", + " layers: 'counts'" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ad" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/cane/Documents/scvi-tools/scvi-hub-models/data/neurips_bone_marrow_cite.h5mu\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/anndata/_core/anndata.py:1758: UserWarning: Variable names are not unique. To make them unique, call `.var_names_make_unique`.\n", + " utils.warn_names_duplicates(\"var\")\n", + "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/mudata/_core/mudata.py:1531: FutureWarning: From 0.4 .update() will not pull obs/var columns from individual modalities by default anymore. Set mudata.set_options(pull_on_update=False) to adopt the new behaviour, which will become the default. Use new pull_obs/pull_var and push_obs/push_var methods for more flexibility.\n", + " self._update_attr(\"var\", axis=0, join_common=join_common)\n", + "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/mudata/_core/mudata.py:931: UserWarning: Cannot join columns with the same name because var_names are intersecting.\n", + " warnings.warn(\n", + "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/mudata/_core/mudata.py:1429: FutureWarning: From 0.4 .update() will not pull obs/var columns from individual modalities by default anymore. Set mudata.set_options(pull_on_update=False) to adopt the new behaviour, which will become the default. Use new pull_obs/pull_var and push_obs/push_var methods for more flexibility.\n", + " self._update_attr(\"obs\", axis=1, join_common=join_common)\n", + "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/mudata/_core/mudata.py:1531: FutureWarning: From 0.4 .update() will not pull obs/var columns from individual modalities by default anymore. Set mudata.set_options(pull_on_update=False) to adopt the new behaviour, which will become the default. Use new pull_obs/pull_var and push_obs/push_var methods for more flexibility.\n", + " self._update_attr(\"var\", axis=0, join_common=join_common)\n", + "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/mudata/_core/mudata.py:1429: FutureWarning: From 0.4 .update() will not pull obs/var columns from individual modalities by default anymore. Set mudata.set_options(pull_on_update=False) to adopt the new behaviour, which will become the default. Use new pull_obs/pull_var and push_obs/push_var methods for more flexibility.\n", + " self._update_attr(\"obs\", axis=1, join_common=join_common)\n" + ] + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "mdata = self.get_adata()" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "GEX_n_genes_by_counts 893\n", + "GEX_pct_counts_mt 6.723979\n", + "GEX_size_factors 0.356535\n", + "GEX_phase G1\n", + "ADT_n_antibodies_by_counts 115\n", + "ADT_total_counts 2828.0\n", + "ADT_iso_count 5.0\n", + "cell_type Naive CD20+ B IGKC+\n", + "batch s1d1\n", + "ADT_pseudotime_order NaN\n", + "GEX_pseudotime_order NaN\n", + "Samplename site1_donor1_cite\n", + "Site site1\n", + "DonorNumber donor1\n", + "Modality cite\n", + "VendorLot 3054455\n", + "DonorID 15078\n", + "DonorAge 34\n", + "DonorBMI 24.8\n", + "DonorBloodType B-\n", + "DonorRace White\n", + "Ethnicity HISPANIC OR LATINO\n", + "DonorGender Male\n", + "QCMeds False\n", + "DonorSmoker Nonsmoker\n", + "is_train train\n", + "_scvi_batch 0\n", + "cv_cell 7.578709\n", + "Name: GCATTAGCATAAGCGG-1-s1d1, dtype: object" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mdata['rna'].obs.iloc[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO Computing empirical prior initialization for protein background. \n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n", + "You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/torch/optim/lr_scheduler.py:62: UserWarning: The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.\n", + " warnings.warn(\n", + "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/lightning/pytorch/core/optimizer.py:316: The lr scheduler dict contains the key(s) ['monitor'], but the keys will be ignored. You need to call `lr_scheduler.step()` manually in manual optimization.\n", + "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.\n", + "/home/cane/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 200/200: 100%|██████████| 200/200 [22:05<00:00, 7.33s/it, v_num=1, train_loss_step=1.58e+3, train_loss_epoch=1.6e+3]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=200` reached.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 200/200: 100%|██████████| 200/200 [22:05<00:00, 6.63s/it, v_num=1, train_loss_step=1.58e+3, train_loss_epoch=1.6e+3]\n" + ] + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['rna', 'protein']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/cane/Documents/scvi-tools/src/scvi/criticism/_ppc.py:293: UserWarning: n_top_genes_fallback=100 is greater than 10% of the number ofgenes f(134) in the dataset. Setting it to 10%.\n", + " warnings.warn(\n" + ] + }, + { + "ename": "ValueError", + "evalue": "Only one class present in y_true. ROC AUC score is not defined in that case.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[23], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_model(mdata)\n\u001b[0;32m----> 2\u001b[0m model_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_minify_and_save_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m hub_model \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_create_hub_model(model_path)\n\u001b[1;32m 4\u001b[0m hub_model \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_upload_hub_model(hub_model)\n", + "File \u001b[0;32m~/Documents/scvi-tools/scvi-hub-models/src/scvi_hub_models/models/_base_workflow.py:239\u001b[0m, in \u001b[0;36m_minify_and_save_model\u001b[0;34m(self, model, adata)\u001b[0m\n\u001b[1;32m 237\u001b[0m qzm, qzv \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mget_latent_representation(give_mean\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, return_dist\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 238\u001b[0m adata\u001b[38;5;241m.\u001b[39mobsm[qzm_key] \u001b[38;5;241m=\u001b[39m qzm\n\u001b[0;32m--> 239\u001b[0m adata\u001b[38;5;241m.\u001b[39mobsm[qzv_key] \u001b[38;5;241m=\u001b[39m qzv\n\u001b[1;32m 240\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(adata, mudata\u001b[38;5;241m.\u001b[39mMuData):\n\u001b[1;32m 241\u001b[0m model\u001b[38;5;241m.\u001b[39mminify_mudata(use_latent_qzm_key\u001b[38;5;241m=\u001b[39mqzm_key, use_latent_qzv_key\u001b[38;5;241m=\u001b[39mqzv_key)\n", + "File \u001b[0;32m~/Documents/scvi-tools/src/scvi/criticism/_create_criticism_report.py:76\u001b[0m, in \u001b[0;36mcreate_criticism_report\u001b[0;34m(model, adata, skip_metrics, n_samples, label_key, save_folder)\u001b[0m\n\u001b[1;32m 74\u001b[0m md_cell_wise_cv, md_gene_wise_cv, md_de \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 75\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m modalities:\n\u001b[0;32m---> 76\u001b[0m md_cell_wise_cv_, md_gene_wise_cv_, md_de_ \u001b[38;5;241m=\u001b[39m \u001b[43mcompute_metrics\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 77\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43madata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mskip_metrics\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_samples\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabel_key\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodality\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mi\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 78\u001b[0m md_cell_wise_cv \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModality: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m md_cell_wise_cv_ \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 79\u001b[0m md_gene_wise_cv \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModality: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m md_gene_wise_cv_ \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n", + "File \u001b[0;32m~/Documents/scvi-tools/src/scvi/criticism/_create_criticism_report.py:116\u001b[0m, in \u001b[0;36mcompute_metrics\u001b[0;34m(model, adata, skip_metrics, n_samples, label_key, modality)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m label_key \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m labels_state_registry\u001b[38;5;241m.\u001b[39moriginal_key \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_scvi_labels\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 115\u001b[0m label_key \u001b[38;5;241m=\u001b[39m labels_state_registry\u001b[38;5;241m.\u001b[39moriginal_key\n\u001b[0;32m--> 116\u001b[0m \u001b[43mppc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdifferential_expression\u001b[49m\u001b[43m(\u001b[49m\u001b[43mde_groupby\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabel_key\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mp_val_thresh\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.2\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 117\u001b[0m summary_df \u001b[38;5;241m=\u001b[39m ppc\u001b[38;5;241m.\u001b[39mmetrics[METRIC_DIFF_EXP][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msummary\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mset_index(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgroup\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 118\u001b[0m summary_df \u001b[38;5;241m=\u001b[39m summary_df\u001b[38;5;241m.\u001b[39mdrop(columns\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n", + "File \u001b[0;32m~/Documents/scvi-tools/src/scvi/utils/_dependencies.py:24\u001b[0m, in \u001b[0;36mdependencies.
MuData object with n_obs × n_vars = 72042 × 16104\n", + " obs:\t'_scvi_labels', 'cv_cell'\n", + " var:\t'cv_gene'\n", + " uns:\t'_scvi_uuid', '_scvi_manager_uuid'\n", + " 2 modalities\n", + " rna:\t72042 x 15993\n", + " obs:\t'percent_mito', 'n_counts', 'n_genes', 'n_protein_counts', 'n_proteins', 'leiden_totalVI_res1.4', 'leiden_totalVI_res1.0', 'leiden_totalVI_res0.6', 'annotations_clean', 'mean_pseudotime', 'Pseudotime_bin', 'curve1', 'curve2', 'difference', 'weight_curve1', 'weight_curve2', 'UMIs_RNA', 'UMIs_protein', 'n_genes_pt', 'n_proteins_pt', 'percent_mito_pt', 'Experiment', 'slingshot_clusters', 'organism_ontology_term_id', 'disease_ontology_term_id', 'sex_ontology_term_id', 'tissue_type', 'tissue_ontology_term_id', 'suspension_type', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'development_stage_ontology_term_id', 'batch_indices', 'sample_id', 'Location', 'donor_id', 'sample_weeks', 'genotype', 'Lineage_by_genotypeSlingshot', 'Lineage_by_genotype', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', '_scvi_batch'\n", + " var:\t'gene_id', 'gene_name', 'expression_type', 'n_cells', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'feature_type', 'n_counts'\n", + " uns:\t'AB_adata', 'annotations_clean_colors', 'batch_condition', 'batch_indices_colors', 'citation', 'leiden', 'neighbors', 'protein_names', 'schema_reference', 'schema_version', 'title', 'totalVI_genes', 'totalVI_proteins', 'umap'\n", + " obsm:\t'X_totalVI', 'X_umap', 'denoised_genes', 'denoised_proteins'\n", + " layers:\t'counts'\n", + " protein:\t72042 x 111\n", + " obs:\t'percent_mito', 'n_counts', 'n_genes', 'n_protein_counts', 'n_proteins', 'leiden_totalVI_res1.4', 'leiden_totalVI_res1.0', 'leiden_totalVI_res0.6', 'annotations_clean', 'mean_pseudotime', 'Pseudotime_bin', 'curve1', 'curve2', 'difference', 'weight_curve1', 'weight_curve2', 'UMIs_RNA', 'UMIs_protein', 'n_genes_pt', 'n_proteins_pt', 'percent_mito_pt', 'Experiment', 'slingshot_clusters', 'organism_ontology_term_id', 'disease_ontology_term_id', 'sex_ontology_term_id', 'tissue_type', 'tissue_ontology_term_id', 'suspension_type', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'development_stage_ontology_term_id', 'batch_indices', 'sample_id', 'Location', 'donor_id', 'sample_weeks', 'genotype', 'Lineage_by_genotypeSlingshot', 'Lineage_by_genotype', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', '_scvi_batch'" + ], + "text/plain": [ + "MuData object with n_obs × n_vars = 72042 × 16104\n", + " obs:\t'_scvi_labels', 'cv_cell'\n", + " var:\t'cv_gene'\n", + " uns:\t'_scvi_uuid', '_scvi_manager_uuid'\n", + " 2 modalities\n", + " rna:\t72042 x 15993\n", + " obs:\t'percent_mito', 'n_counts', 'n_genes', 'n_protein_counts', 'n_proteins', 'leiden_totalVI_res1.4', 'leiden_totalVI_res1.0', 'leiden_totalVI_res0.6', 'annotations_clean', 'mean_pseudotime', 'Pseudotime_bin', 'curve1', 'curve2', 'difference', 'weight_curve1', 'weight_curve2', 'UMIs_RNA', 'UMIs_protein', 'n_genes_pt', 'n_proteins_pt', 'percent_mito_pt', 'Experiment', 'slingshot_clusters', 'organism_ontology_term_id', 'disease_ontology_term_id', 'sex_ontology_term_id', 'tissue_type', 'tissue_ontology_term_id', 'suspension_type', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'development_stage_ontology_term_id', 'batch_indices', 'sample_id', 'Location', 'donor_id', 'sample_weeks', 'genotype', 'Lineage_by_genotypeSlingshot', 'Lineage_by_genotype', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', '_scvi_batch'\n", + " var:\t'gene_id', 'gene_name', 'expression_type', 'n_cells', 'feature_is_filtered', 'feature_name', 'feature_reference', 'feature_biotype', 'feature_length', 'feature_type', 'n_counts'\n", + " uns:\t'AB_adata', 'annotations_clean_colors', 'batch_condition', 'batch_indices_colors', 'citation', 'leiden', 'neighbors', 'protein_names', 'schema_reference', 'schema_version', 'title', 'totalVI_genes', 'totalVI_proteins', 'umap'\n", + " obsm:\t'X_totalVI', 'X_umap', 'denoised_genes', 'denoised_proteins'\n", + " layers:\t'counts'\n", + " protein:\t72042 x 111\n", + " obs:\t'percent_mito', 'n_counts', 'n_genes', 'n_protein_counts', 'n_proteins', 'leiden_totalVI_res1.4', 'leiden_totalVI_res1.0', 'leiden_totalVI_res0.6', 'annotations_clean', 'mean_pseudotime', 'Pseudotime_bin', 'curve1', 'curve2', 'difference', 'weight_curve1', 'weight_curve2', 'UMIs_RNA', 'UMIs_protein', 'n_genes_pt', 'n_proteins_pt', 'percent_mito_pt', 'Experiment', 'slingshot_clusters', 'organism_ontology_term_id', 'disease_ontology_term_id', 'sex_ontology_term_id', 'tissue_type', 'tissue_ontology_term_id', 'suspension_type', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'development_stage_ontology_term_id', 'batch_indices', 'sample_id', 'Location', 'donor_id', 'sample_weeks', 'genotype', 'Lineage_by_genotypeSlingshot', 'Lineage_by_genotype', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid', '_scvi_batch'" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mdata.update()\n", + "mdata['protein'].obs = mdata['rna'].obs\n", + "mdata" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['rna', 'protein']\n" + ] + }, + { + "ename": "KeyError", + "evalue": "\"Values ['rna', 'protein'], from ['rna', 'protein'], are not valid obs/ var names or indices.\"", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[19], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m model_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_minify_and_save_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmdata\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Documents/scvi-tools/scvi-hub-models/src/scvi_hub_models/models/_base_workflow.py:267\u001b[0m, in \u001b[0;36mBaseModelWorkflow._minify_and_save_model\u001b[0;34m(self, model, adata)\u001b[0m\n\u001b[1;32m 265\u001b[0m os\u001b[38;5;241m.\u001b[39mmakedirs(mini_model_path)\n\u001b[1;32m 266\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcreate_criticism_report\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mTrue\u001b[39;00m) \u001b[38;5;129;01mand\u001b[39;00m model\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m \u001b[38;5;129;01min\u001b[39;00m SUPPORTED_PPC_MODELS:\n\u001b[0;32m--> 267\u001b[0m \u001b[43mcreate_criticism_report\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 268\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[43msave_folder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmini_model_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m \u001b[49m\u001b[43mn_samples\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcriticism_settings\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mn_samples\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabel_key\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcriticism_settings\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcell_type_key\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mminify_model\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mTrue\u001b[39;00m) \u001b[38;5;129;01mand\u001b[39;00m model\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m \u001b[38;5;129;01min\u001b[39;00m SUPPORTED_MINIFIED_MODELS:\n\u001b[1;32m 275\u001b[0m qzm_key \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel_name\u001b[38;5;241m.\u001b[39mlower()\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m_latent_qzm\u001b[39m\u001b[38;5;124m\"\u001b[39m\n", + "File \u001b[0;32m~/Documents/scvi-tools/src/scvi/criticism/_create_criticism_report.py:76\u001b[0m, in \u001b[0;36mcreate_criticism_report\u001b[0;34m(model, adata, skip_metrics, n_samples, label_key, save_folder)\u001b[0m\n\u001b[1;32m 74\u001b[0m md_cell_wise_cv, md_gene_wise_cv, md_de \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 75\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m [modalities]:\n\u001b[0;32m---> 76\u001b[0m md_cell_wise_cv_, md_gene_wise_cv_, md_de_ \u001b[38;5;241m=\u001b[39m \u001b[43mcompute_metrics\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 77\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43madata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mskip_metrics\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_samples\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabel_key\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodality\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mi\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 78\u001b[0m md_cell_wise_cv \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModality: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m md_cell_wise_cv_ \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 79\u001b[0m md_gene_wise_cv \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModality: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m md_gene_wise_cv_ \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n", + "File \u001b[0;32m~/Documents/scvi-tools/src/scvi/criticism/_create_criticism_report.py:99\u001b[0m, in \u001b[0;36mcompute_metrics\u001b[0;34m(model, adata, skip_metrics, n_samples, label_key, modality)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompute_metrics\u001b[39m(model, adata, skip_metrics, n_samples, label_key, modality\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m 98\u001b[0m models_dict \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel\u001b[39m\u001b[38;5;124m\"\u001b[39m: model}\n\u001b[0;32m---> 99\u001b[0m ppc \u001b[38;5;241m=\u001b[39m \u001b[43mPPC\u001b[49m\u001b[43m(\u001b[49m\u001b[43madata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodels_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_samples\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_samples\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodality\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodality\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 100\u001b[0m \u001b[38;5;66;03m# run ppc+cv\u001b[39;00m\n\u001b[1;32m 101\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m METRIC_CV_CELL \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m skip_metrics:\n", + "File \u001b[0;32m~/Documents/scvi-tools/src/scvi/criticism/_ppc.py:85\u001b[0m, in \u001b[0;36mPosteriorPredictiveCheck.__init__\u001b[0;34m(self, adata, models_dict, count_layer_key, n_samples, indices, modality)\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(adata, MuData):\n\u001b[1;32m 84\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m modality \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModality must be defined for MuData.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m---> 85\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39madata \u001b[38;5;241m=\u001b[39m \u001b[43madata\u001b[49m\u001b[43m[\u001b[49m\u001b[43mmodality\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 86\u001b[0m raw_counts \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 87\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39madata\u001b[38;5;241m.\u001b[39mlayers[count_layer_key] \u001b[38;5;28;01mif\u001b[39;00m count_layer_key \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 88\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39madata\u001b[38;5;241m.\u001b[39mX\n\u001b[1;32m 89\u001b[0m )\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", + "File \u001b[0;32m~/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/mudata/_core/mudata.py:516\u001b[0m, in \u001b[0;36mMuData.__getitem__\u001b[0;34m(self, index)\u001b[0m\n\u001b[1;32m 514\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmod[index]\n\u001b[1;32m 515\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mMuData\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mas_view\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindex\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/mudata/_core/mudata.py:166\u001b[0m, in \u001b[0;36mMuData.__init__\u001b[0;34m(self, data, feature_types_names, as_view, index, **kwargs)\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_init_common()\n\u001b[1;32m 165\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m as_view:\n\u001b[0;32m--> 166\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_init_as_view\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 167\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;66;03m# Add all modalities to a MuData object\u001b[39;00m\n", + "File \u001b[0;32m~/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/mudata/_core/mudata.py:265\u001b[0m, in \u001b[0;36mMuData._init_as_view\u001b[0;34m(self, mudata_ref, index)\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_init_as_view\u001b[39m(\u001b[38;5;28mself\u001b[39m, mudata_ref: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMuData\u001b[39m\u001b[38;5;124m\"\u001b[39m, index):\n\u001b[1;32m 263\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01manndata\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01m_core\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mindex\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m _normalize_indices\n\u001b[0;32m--> 265\u001b[0m obsidx, varidx \u001b[38;5;241m=\u001b[39m \u001b[43m_normalize_indices\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmudata_ref\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mobs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mindex\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmudata_ref\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvar\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 267\u001b[0m \u001b[38;5;66;03m# to handle single-element subsets, otherwise when subsetting a Dataframe\u001b[39;00m\n\u001b[1;32m 268\u001b[0m \u001b[38;5;66;03m# we get a Series\u001b[39;00m\n\u001b[1;32m 269\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(obsidx, Integral):\n", + "File \u001b[0;32m~/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/anndata/_core/index.py:32\u001b[0m, in \u001b[0;36m_normalize_indices\u001b[0;34m(index, names0, names1)\u001b[0m\n\u001b[1;32m 30\u001b[0m index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mtuple\u001b[39m(i\u001b[38;5;241m.\u001b[39mvalues \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(i, pd\u001b[38;5;241m.\u001b[39mSeries) \u001b[38;5;28;01melse\u001b[39;00m i \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m index)\n\u001b[1;32m 31\u001b[0m ax0, ax1 \u001b[38;5;241m=\u001b[39m unpack_index(index)\n\u001b[0;32m---> 32\u001b[0m ax0 \u001b[38;5;241m=\u001b[39m \u001b[43m_normalize_index\u001b[49m\u001b[43m(\u001b[49m\u001b[43max0\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnames0\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 33\u001b[0m ax1 \u001b[38;5;241m=\u001b[39m _normalize_index(ax1, names1)\n\u001b[1;32m 34\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ax0, ax1\n", + "File \u001b[0;32m~/.local/share/hatch/env/virtual/scvi-tools/eVVa01t5/scvi-tools/lib/python3.12/site-packages/anndata/_core/index.py:99\u001b[0m, in \u001b[0;36m_normalize_index\u001b[0;34m(indexer, index)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m np\u001b[38;5;241m.\u001b[39many(positions \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m0\u001b[39m):\n\u001b[1;32m 98\u001b[0m not_found \u001b[38;5;241m=\u001b[39m indexer[positions \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m---> 99\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(\n\u001b[1;32m 100\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mValues \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlist\u001b[39m(not_found)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, from \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlist\u001b[39m(indexer)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 101\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mare not valid obs/ var names or indices.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 102\u001b[0m )\n\u001b[1;32m 103\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m positions \u001b[38;5;66;03m# np.ndarray[int]\u001b[39;00m\n\u001b[1;32m 104\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mIndexError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnknown indexer \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mindexer\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m of type \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(indexer)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", + "\u001b[0;31mKeyError\u001b[0m: \"Values ['rna', 'protein'], from ['rna', 'protein'], are not valid obs/ var names or indices.\"" + ] + } + ], + "source": [ + "model_path = self._minify_and_save_model(model, mdata)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "hub_model = self._create_hub_model(model_path)\n", + "hub_model = self._upload_hub_model(hub_model)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "scvi-tools", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}