Skip to content

Commit

Permalink
Revision done.
Browse files Browse the repository at this point in the history
  • Loading branch information
canergen committed Dec 12, 2024
1 parent 84bc539 commit 7a5886e
Show file tree
Hide file tree
Showing 15 changed files with 821 additions and 174 deletions.
12 changes: 10 additions & 2 deletions src/scvi_hub_models/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,15 @@
@click.option("--dry_run", type=bool, default=False, help="Dry run the workflow.")
@click.option("--config_key", type=str, help="Use a different config file, e.g. for test purpose.")
@click.option("--save_dir", type=str, help="Directory to save intermediate results (defaults temporary).")
def run_workflow(model_name: str, dry_run: bool, config_key: str = None, save_dir: str = None) -> None:
@click.option("--reload_data", type=bool, help="Reload the data or get from DVC.")
@click.option("--reload_model", type=bool, help="Reload the model or get from DVC.")
def run_workflow(
model_name: str,
dry_run: bool,
config_key: str = None,
save_dir: str = None,
reload_data: bool = False,
reload_model: bool = False) -> None:
"""Run the workflow for a specific model."""
from importlib import import_module
if not config_key:
Expand All @@ -22,7 +30,7 @@ def run_workflow(model_name: str, dry_run: bool, config_key: str = None, save_di
Workflow = workflow_module._Workflow
config = json_data_store[config_key]

workflow = Workflow(save_dir=save_dir, dry_run=dry_run, config=config)
workflow = Workflow(save_dir=save_dir, dry_run=dry_run, config=config, reload_data=reload_data, reload_model=reload_model)
workflow.run()


Expand Down
5 changes: 4 additions & 1 deletion src/scvi_hub_models/config/haniffa_covid_pbmc.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
"model_dir": "haniffa_covid_pbmc",
"model_class": "TOTALVI",
"repo_name": "scvi-tools/haniffa_covid_pbmc_totalvi",
"reload_data": true,
"extra_data_kwargs": {
"reference_adata_cxg_id": "c7775e88-49bf-4ba2-a03b-93f00447c958",
"reference_adata_fname": "haniffa_covid_pbmc.h5ad"
"reference_adata_fname": "haniffa_covid_pbmc.h5ad",
"large_training_file_name": "haniffa_covid_pbmc.h5mu"
},
"metadata": {
"training_data_url": "https://datasets.cellxgene.cziscience.com/5ad66a4f-d619-4cb3-8015-a87c755647b3.h5ad",
Expand All @@ -16,6 +18,7 @@
"description": "CITE-seq to measure RNA and surface proteins in thymocytes from wild-type and T cell lineage-restricted mice to generate a comprehensive timeline of cell state for each T cell lineage.",
"references": "Steier, Z., Aylard, D.A., McIntyre, L.L. et al. Single-cell multiomic analysis of thymocyte development reveals drivers of CD4+ T cell and CD8+ T cell lineage commitment. Nat Immunol 24, 1579–1590 (2023). https://doi.org/10.1038/s41590-023-01584-0."
},

"criticism_settings": {
"n_samples": 3,
"cell_type_key": "cell_type"
Expand Down
4 changes: 4 additions & 0 deletions src/scvi_hub_models/config/heart_cell_atlas.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
"model_dir": "heart_cell_atlas_scvi",
"model_class": "SCVI",
"repo_name": "scvi-tools/heart-cell-atlas-scvi",
"extra_data_kwargs": {
"reference_adata_fname": "heart_cell_atlas.h5ad",
"large_training_file_name": "heart_cell_atlas.h5ad"
},
"metadata": {
"training_data_url": "https://www.heartcellatlas.org/#DataSources",
"tissues": ["heart"],
Expand Down
5 changes: 3 additions & 2 deletions src/scvi_hub_models/config/human_lung_cell_atlas.json
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
{
"model_dir": "hlca_scanvi_reference",
"model_dir": "hlca_reference_scanvi",
"model_class": "SCANVI",
"repo_name": "scvi-tools/human-lung-cell-atlas-scanvi",
"extra_data_kwargs": {
"legacy_model_url": "https://zenodo.org/records/7599104/files/HLCA_reference_model.zip",
"legacy_model_hash": "a7cd60f4342292b3cba54545bcd8a34decdc8e6b82163f009273d543e7e3910e",
"legacy_model_dir": "hlca_scanvi_reference_legacy",
"reference_adata_cxg_id": "066943a2-fdac-4b29-b348-40cede398e4e",
"reference_adata_fname": "hlca_core.h5ad"
"reference_adata_fname": "hlca_core.h5ad",
"large_training_file_name": "hlca_core.h5ad"
},
"metadata": {
"training_data_url": "https://cellxgene.cziscience.com/collections/6f6d381a-7701-4781-935c-db10d30de293",
Expand Down
7 changes: 3 additions & 4 deletions src/scvi_hub_models/config/mouse_thymus_cite.json
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
{
"model_dir": "mouse_thymus_cite",
"model_dir": "mouse_thymus_cite_totalvi",
"model_class": "TOTALVI",
"repo_name": "scvi-tools/mouse_thymus_totalvi",
"reload_data": true,
"repo_name": "scvi-tools/mouse_thymus_cite_totalvi",
"extra_data_kwargs": {
"reference_adata_cxg_id": "c14c54f8-85d8-45db-9de7-6ab572cc748a",
"reference_adata_fname": "thymus_cite.h5ad",
"reference_adata_fname": "mouse_thymus_cite.h5ad",
"large_training_file_name": "mouse_thymus_cite.h5mu"
},
"metadata": {
Expand Down
7 changes: 5 additions & 2 deletions src/scvi_hub_models/config/neurips_cite.json
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
{
"model_dir": "bone_marrow_cite",
"model_dir": "bone_marrow_cite_totalvi",
"model_class": "TOTALVI",
"repo_name": "scvi-tools/bone_marrow_cite_totalvi",
"extra_data_kwargs": {
"reference_adata_fname": "bmmc_cite.h5ad"
"url": "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE194122&format=file&file=GSE194122%5Fopenproblems%5Fneurips2021%5Fcite%5FBMMC%5Fprocessed%2Eh5ad%2Egz",
"hash": "b9b50fade9349719cba23c97c6515d3501a32ee3735fe95fe51221d2e8a5f361",
"reference_adata_fname": "bmmc_cite.h5ad.gz",
"large_training_file_name": "neurips_bone_marrow_cite.h5mu"
},
"metadata": {
"training_data_url": "https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE194122&format=file&file=GSE194122%5Fopenproblems%5Fneurips2021%5Fcite%5FBMMC%5Fprocessed%2Eh5ad%2Egz",
Expand Down
4 changes: 3 additions & 1 deletion src/scvi_hub_models/config/test_scvi.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
"model_dir": "test_scvi",
"model_class": "SCVI",
"repo_name": "scvi-tools/test-scvi",
"extra_data_kwargs": {
"large_training_file_name": "test_data.h5ad"
},
"collection_name": "test",
"minify_model": false,
"extra_data_kwargs": {},
"metadata": {
"tissues": ["synthetic"],
"data_modalities": ["rna"],
Expand Down
184 changes: 103 additions & 81 deletions src/scvi_hub_models/models/_base_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,9 @@
from scvi.hub import HubMetadata, HubModel, HubModelCardHelper
from scvi.model.base import BaseModelClass

import subprocess, os
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive


def upload_gdrive(path):
    """Upload *path* to the Google Drive folder backing the DVC remote.

    Parses the folder ID out of the configured ``gdrive://`` DVC remote URL
    and, when ``dvc diff`` reports modifications, authenticates with Google
    Drive and uploads the file into that folder.

    Parameters
    ----------
    path
        Local path of the file to upload; its basename is used as the
        Drive file title.
    """
    # Last whitespace-separated token of `dvc remote list` output is the URL.
    # NOTE(review): assumes exactly one remote is configured — confirm.
    remote_url = subprocess.check_output(["dvc", "remote", "list"], text=True).split()[-1]

    # Extract the Google Drive folder ID from the remote URL
    folder_id = remote_url.split("gdrive://")[-1]

    # Check if DVC detected an update
    if "modified" in subprocess.check_output(["dvc", "diff", "--json"], text=True):
        # Bug fix: reuse the authenticated GoogleAuth instance. The original
        # code authenticated one instance and then constructed the drive from
        # a second, *unauthenticated* GoogleAuth().
        gauth = GoogleAuth()
        gauth.LocalWebserverAuth()
        drive = GoogleDrive(gauth)
        file = drive.CreateFile({"title": os.path.basename(path), "parents": [{"id": folder_id}]})
        file.SetContentFile(path)
        file.Upload()
        print(f"Uploaded {path} to Google Drive folder {folder_id}.")

# Specify your repository and target file
repo_path = "."

repo_path = os.path.abspath(Path(__file__).parent.parent.parent.parent)
dvc_repo = Repo(repo_path)
git_repo = git.Repo(repo_path)

Expand Down Expand Up @@ -67,17 +48,25 @@ class BaseModelWorkflow:
config
A :class:`~frozendict.frozendict` containing the configuration for the workflow. Can only
be set once.
reload_data
If ``True``, the data will be reloaded. Otherwise, it will be pulled from DVC. Defaults to ``False``.
reload_model
If ``True``, the model will be reloaded. Otherwise, it will be pulled from DVC. Defaults to ``False``.
"""

def __init__(
    self,
    save_dir: str | None = None,
    dry_run: bool = False,
    config: frozendict | None = None,
    reload_data: bool = False,
    reload_model: bool = False,
):
    """Initialize the workflow.

    Parameters
    ----------
    save_dir
        Directory in which to save intermediate results.
    dry_run
        If ``True``, steps are logged but not executed.
    config
        Workflow configuration; can only be set once.
    reload_data
        If ``True``, the data will be reloaded; otherwise it is pulled
        from DVC. Defaults to ``False``.
    reload_model
        If ``True``, the model will be reloaded; otherwise it is pulled
        from DVC. Defaults to ``False``.

    Notes
    -----
    Defaults for ``reload_data``/``reload_model`` changed from ``True`` to
    ``False`` to match the class docstring and the CLI option defaults.
    """
    # Assignments route through the class's set-once property setters
    # (save_dir/config/reload_* have setters; presumably dry_run too — verify).
    self.save_dir = save_dir
    self.dry_run = dry_run
    self.config = config
    self.reload_data = reload_data
    self.reload_model = reload_model

@property
def save_dir(self):
Expand Down Expand Up @@ -114,57 +103,82 @@ def config(self, value: frozendict):
value = frozendict(value)
self._config = value

@property
def reload_data(self):
    """Whether the data is re-downloaded instead of pulled from DVC."""
    return self._reload_data

@reload_data.setter
def reload_data(self, value: bool):
    # Write-once guard: rejecting a second assignment keeps the flag stable
    # for the lifetime of the workflow.
    already_set = hasattr(self, "_reload_data")
    if already_set:
        raise AttributeError("`reload_data` can only be set once.")
    self._reload_data = value

@property
def reload_model(self):
    """Whether the model is re-created instead of pulled from DVC."""
    return self._reload_model

@reload_model.setter
def reload_model(self, value: bool):
    # Write-once guard: rejecting a second assignment keeps the flag stable
    # for the lifetime of the workflow.
    already_set = hasattr(self, "_reload_model")
    if already_set:
        raise AttributeError("`reload_model` can only be set once.")
    self._reload_model = value

def get_adata(self) -> anndata.AnnData | None:
    """Download and load the dataset.

    When ``reload_data`` is ``True``, the dataset is (re-)downloaded, added
    to DVC, committed to git, and pushed to the DVC/git remotes. Otherwise
    the tracked file is pulled from the DVC remote and read from disk.

    Returns
    -------
    The loaded :class:`~anndata.AnnData` (or a MuData object for ``.h5mu``
    files), or ``None`` on a dry run.
    """
    logger.info("Loading dataset.")
    if self.dry_run:
        return None
    # The tracked file lives under <repo>/data/ in both branches; build the
    # path once instead of duplicating it.
    path_file = os.path.join(
        repo_path, "data", self.config["extra_data_kwargs"]["large_training_file_name"]
    )
    if self.reload_data:
        logger.debug("Data file path: %s", path_file)
        adata = self.download_adata(path_file)
        dvc_repo.add(path_file)
        git_repo.index.commit(f"Track {path_file} with DVC")
        logger.info("Pushing %s to DVC remote...", path_file)
        dvc_repo.push()
        git_repo.remote().push()
    else:
        dvc_repo.pull([path_file])
        # NOTE(review): ``.h5mu`` implies MuData; everything else is read as
        # AnnData — confirm all configured file names follow this convention.
        if path_file.endswith(".h5mu"):
            adata = mudata.read_h5mu(path_file)
        else:
            adata = anndata.read_h5ad(path_file)
    return adata

def _get_adata(self, url: str, hash: str, file_path: str) -> str:
def get_model(self, adata) -> BaseModelClass | None:
    """Download and load the model.

    When ``reload_model`` is ``True``, the model is rebuilt via
    ``load_model``, saved under ``<repo>/data/<model_dir>``, tracked with
    DVC, committed, and pushed. Otherwise the saved model directory is
    pulled from the DVC remote and loaded from disk.

    Parameters
    ----------
    adata
        The dataset to associate with the model.

    Returns
    -------
    The loaded model, or ``None`` on a dry run.
    """
    logger.info("Loading model.")
    if self.dry_run:
        return None
    # The model directory is identical in both branches; compute it once.
    path_file = os.path.join(repo_path, "data", self.config["model_dir"])
    if self.reload_model:
        model = self.load_model(adata)
        model.save(path_file, overwrite=True, save_anndata=False)
        dvc_repo.add(path_file)
        git_repo.index.commit(f"Track {path_file} with DVC")
        dvc_repo.push()
        git_repo.remote().push()
    else:
        dvc_repo.pull([path_file])
        model = self.default_load_model(adata, self.config["model_class"], path_file)
    return model

def _get_adata(self, url: str, hash: str, file_path: str, processor: str | None = None) -> anndata.AnnData | None:
    """Download a data file and read it as AnnData.

    Parameters
    ----------
    url
        Download URL.
    hash
        Known hash of the file, verified by ``retrieve``.
        (Name kept for caller compatibility although it shadows the builtin.)
    file_path
        File name to store the download under inside ``save_dir``.
    processor
        Optional post-download processor (e.g. a decompressor) forwarded
        to ``retrieve``.

    Returns
    -------
    The loaded :class:`~anndata.AnnData`, or ``None`` on a dry run.
    (Return annotation fixed: the original declared ``-> str`` but returns
    an AnnData object or ``None``.)
    """
    logger.info("Downloading and reading data.")
    if self.dry_run:
        return None

    # ``retrieve`` returns the full local path of the (possibly processed)
    # downloaded file.
    file_out = retrieve(
        url=url,
        known_hash=hash,
        fname=file_path,
        path=self.save_dir,
        processor=processor,
    )
    return anndata.read_h5ad(file_out)

def _load_model(self, model_path: str, adata: anndata.AnnData, model_name: str):
def default_load_model(self, adata: anndata.AnnData, model_name: str, model_path: str | None = None) -> BaseModelClass:
"""Load the model."""
logger.info("Loading model.")
if self.dry_run:
Expand All @@ -178,48 +192,20 @@ def _load_model(self, model_path: str, adata: anndata.AnnData, model_name: str):
elif model_name == "CondSCVI":
from scvi.model import CondSCVI
model_cls = CondSCVI
elif model_name == "TOTALVI":
from scvi.model import TOTALVI
model_cls = TOTALVI
elif model_name == "Stereoscope":
from scvi.external import RNAStereoscope
model_cls = RNAStereoscope
else:
raise ValueError(f"Model {model_name} not recognized.")

model = model_cls.load(os.path.join(self.save_dir, model_path), adata=adata)
if model_path is None:
model_path = os.path.join(self.save_dir, self.config["model_dir"])
model = model_cls.load(model_path, adata=adata)
return model

def _create_hub_model(
    self,
    model_path: str,
    training_data_url: str | None = None
) -> HubModel | None:
    """Build an scvi-hub ``HubModel`` from a saved model directory.

    Parameters
    ----------
    model_path
        Directory containing the saved model.
    training_data_url
        URL of the training data; falls back to
        ``config["training_data_url"]`` when not given.

    Returns
    -------
    The assembled ``HubModel``, or ``None`` on a dry run.
    """
    logger.info("Creating the HubModel.")
    if self.dry_run:
        return None

    if training_data_url is None:
        training_data_url = self.config.get("training_data_url", None)

    # Model-card fields come from the "metadata" section of the config, with
    # permissive defaults for anything missing.
    metadata = self.config["metadata"]
    hub_metadata = HubMetadata.from_dir(
        model_path,
        anndata_version=anndata_version
    )
    model_card = HubModelCardHelper.from_dir(
        model_path,
        anndata_version=anndata_version,
        license_info=metadata.get("license_info", "mit"),
        data_modalities=metadata.get("data_modalities", None),
        tissues=metadata.get("tissues", None),
        data_is_annotated=metadata.get("data_is_annotated", False),
        data_is_minified=metadata.get("data_is_minified", False),
        training_data_url=training_data_url,
        training_code_url=metadata.get("training_code_url", None),
        description=metadata.get("description", None),
        references=metadata.get("references", None),
    )

    return HubModel(model_path, hub_metadata, model_card)

def _minify_and_save_model(
self,
model: BaseModelClass,
Expand Down Expand Up @@ -252,11 +238,47 @@ def _minify_and_save_model(
qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True)
adata.obsm[qzm_key] = qzm
adata.obsm[qzv_key] = qzv
model.minify_adata(use_latent_qzm_key=qzm_key, use_latent_qzv_key=qzv_key)
if isinstance(adata, mudata.MuData):
model.minify_mudata(use_latent_qzm_key=qzm_key, use_latent_qzv_key=qzv_key)
else:
model.minify_adata(use_latent_qzm_key=qzm_key, use_latent_qzv_key=qzv_key)
model.save(mini_model_path, overwrite=True, save_anndata=True)

return mini_model_path

def _create_hub_model(
    self,
    model_path: str,
    training_data_url: str | None = None
) -> HubModel | None:
    """Build an scvi-hub ``HubModel`` from a saved model directory.

    Parameters
    ----------
    model_path
        Directory containing the saved model.
    training_data_url
        URL of the training data; falls back to
        ``config["training_data_url"]`` when not given.

    Returns
    -------
    The assembled ``HubModel``, or ``None`` on a dry run.
    """
    logger.info("Creating the HubModel.")
    if self.dry_run:
        return None

    if training_data_url is None:
        training_data_url = self.config.get("training_data_url", None)

    # Model-card fields come from the "metadata" section of the config, with
    # permissive defaults for anything missing.
    metadata = self.config["metadata"]
    hub_metadata = HubMetadata.from_dir(
        model_path,
        anndata_version=anndata_version
    )
    model_card = HubModelCardHelper.from_dir(
        model_path,
        anndata_version=anndata_version,
        license_info=metadata.get("license_info", "mit"),
        data_modalities=metadata.get("data_modalities", None),
        tissues=metadata.get("tissues", None),
        data_is_annotated=metadata.get("data_is_annotated", False),
        data_is_minified=metadata.get("data_is_minified", False),
        training_data_url=training_data_url,
        training_code_url=metadata.get("training_code_url", None),
        description=metadata.get("description", None),
        references=metadata.get("references", None),
    )

    return HubModel(model_path, hub_metadata, model_card)

def _upload_hub_model(self, hub_model: HubModel, repo_name: str | None = None, **kwargs) -> HubModel:
"""Upload the HubModel to Hugging Face."""
collection_name = self.config.get("collection_name", None)
Expand Down
Loading

0 comments on commit 7a5886e

Please sign in to comment.