diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3eb0ebe..2a22864 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,8 @@
 * Update `process_dataset` component to subsample large datasets (PR #14).
 
+* Add the scPRINT method (PR #25).
+
 ## MAJOR CHANGES
 
 * Revamp `scripts` directory (PR #13).
 
diff --git a/scripts/run_benchmark/run_test_local.sh b/scripts/run_benchmark/run_test_local.sh
index 55580c0..0d01b0f 100755
--- a/scripts/run_benchmark/run_test_local.sh
+++ b/scripts/run_benchmark/run_test_local.sh
@@ -15,13 +15,19 @@ echo " Make sure to run 'scripts/project/build_all_docker_containers.sh'!"
 RUN_ID="testrun_$(date +%Y-%m-%d_%H-%M-%S)"
 publish_dir="temp/results/${RUN_ID}"
 
+# Write the parameters to a file
+cat > /tmp/params.yaml << HERE
+input_states: resources_test/task_denoising/**/state.yaml
+rename_keys: 'input_train:train;input_test:test'
+output_state: "state.yaml"
+publish_dir: "$publish_dir"
+settings: '{"methods_exclude": ["scprint"]}'
+HERE
+
 nextflow run . \
   -main-script target/nextflow/workflows/run_benchmark/main.nf \
   -profile docker \
   -resume \
+  -entry auto \
   -c common/nextflow_helpers/labels_ci.config \
-  --id cxg_immune_cell_atlas \
-  --input_train resources_test/task_denoising/cxg_immune_cell_atlas/train.h5ad \
-  --input_test resources_test/task_denoising/cxg_immune_cell_atlas/test.h5ad \
-  --output_state state.yaml \
-  --publish_dir "$publish_dir"
+  -params-file /tmp/params.yaml
diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml
new file mode 100644
index 0000000..2886ecb
--- /dev/null
+++ b/src/methods/scprint/config.vsh.yaml
@@ -0,0 +1,77 @@
+__merge__: /src/api/base_method.yaml
+
+name: scprint
+label: scPRINT
+summary: scPRINT is a large transformer model built for the inference of gene networks
+description: |
+  scPRINT is a large transformer model built for the inference of gene networks
+  (connections between genes explaining the cell's expression profile) from
+  scRNAseq data.
+
+  It uses novel encoding and decoding of the cell expression profile and new
+  pre-training methodologies to learn a cell model.
+
+  scPRINT can be used to perform the following analyses:
+
+  - expression denoising: increase the resolution of your scRNAseq data
+  - cell embedding: generate a low-dimensional representation of your dataset
+  - label prediction: predict the cell type, disease, sequencer, sex, and
+    ethnicity of your cells
+  - gene network inference: generate a gene network from any cell or cell
+    cluster in your scRNAseq dataset
+
+references:
+  doi:
+    - 10.1101/2024.07.29.605556
+
+links:
+  documentation: https://cantinilab.github.io/scPRINT/
+  repository: https://github.com/cantinilab/scPRINT
+
+info:
+  preferred_normalization: counts
+  variants:
+    scprint_large:
+      model_name: "large"
+    scprint_medium:
+      model_name: "medium"
+    scprint_small:
+      model_name: "small"
+
+arguments:
+  - name: "--model_name"
+    type: "string"
+    description: Which model to use. Ignored if --model is provided.
+    choices: ["large", "medium", "small"]
+    default: "large"
+  - name: --model
+    type: file
+    description: Path to the scPRINT model checkpoint.
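+    # If no path is given, script.py downloads the '<model_name>.ckpt'
+    # checkpoint from the 'jkobject/scPRINT' Hugging Face repository.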
+    required: false
+
+resources:
+  - type: python_script
+    path: script.py
+
+engines:
+  - type: docker
+    image: openproblems/base_pytorch_nvidia:1.0.0
+    setup:
+      - type: python
+        pip:
+          - huggingface_hub
+          - scprint
+      - type: docker
+        run: lamin init --storage ./main --name main --schema bionty
+      - type: python
+        script: import bionty as bt; bt.core.sync_all_sources_to_latest()
+      - type: docker
+        run: lamin load anonymous/main
+      - type: python
+        script: from scdataloader.utils import populate_my_ontology; populate_my_ontology()
+
+runners:
+  - type: executable
+  - type: nextflow
+    directives:
+      label: [midtime, midmem, midcpu, gpu]
diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py
new file mode 100644
index 0000000..e5f4c4a
--- /dev/null
+++ b/src/methods/scprint/script.py
@@ -0,0 +1,115 @@
+import os
+
+import anndata as ad
+import numpy as np
+import scprint
+import torch
+from huggingface_hub import hf_hub_download
+from scdataloader import Preprocessor
+from scprint import scPrint
+from scprint.tasks import Denoiser
+
+## VIASH START
+par = {
+    "input_train": "resources_test/task_denoising/cxg_immune_cell_atlas/train.h5ad",
+    "output": "output.h5ad",
+    "model_name": "large",
+    "model": None,
+}
+meta = {"name": "scprint"}
+## VIASH END
+
+print(f"====== scPRINT version {scprint.__version__} ======", flush=True)
+
+print("\n>>> Reading input data...", flush=True)
+input = ad.read_h5ad(par["input_train"])
+print(input)
+
+print("\n>>> Preprocessing data...", flush=True)
+adata = ad.AnnData(
+    X=input.layers["counts"]
+)
+adata.obs_names = input.obs_names
+adata.var_names = input.var_names
+if input.uns["dataset_organism"] == "homo_sapiens":
+    adata.obs["organism_ontology_term_id"] = "NCBITaxon:9606"
+elif input.uns["dataset_organism"] == "mus_musculus":
+    adata.obs["organism_ontology_term_id"] = "NCBITaxon:10090"
+else:
+    raise ValueError(
+        f"scPRINT requires human or mouse data, not '{input.uns['dataset_organism']}'"
+    )
+
+preprocessor = Preprocessor(
+    # Lower this threshold for test datasets
+    min_valid_genes_id=1000 if input.n_vars < 2000 else 10000,
+    # Turn off cell filtering to return results for all cells
+    filter_cell_by_counts=False,
+    min_nnz_genes=False,
+    do_postp=False,
+    # Skip ontology checks
+    skip_validate=True,
+)
+adata = preprocessor(adata)
+print(adata)
+
+model_checkpoint_file = par["model"]
+if model_checkpoint_file is None:
+    print(f"\n>>> Downloading '{par['model_name']}' model...", flush=True)
+    model_checkpoint_file = hf_hub_download(
+        repo_id="jkobject/scPRINT", filename=f"{par['model_name']}.ckpt"
+    )
+print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True)
+model = scPrint.load_from_checkpoint(
+    model_checkpoint_file,
+    transformer="normal",  # Don't use this for GPUs with flashattention
+    precpt_gene_emb=None,
+)
+
+print("\n>>> Denoising data...", flush=True)
+if torch.cuda.is_available():
+    print("CUDA is available, using GPU", flush=True)
+    precision = "16-mixed"
+    dtype = torch.float16
+else:
+    print("CUDA is not available, using CPU", flush=True)
+    precision = "32"
+    dtype = torch.float32
+n_cores_available = len(os.sched_getaffinity(0))
+print(f"Using {n_cores_available} worker cores", flush=True)
+denoiser = Denoiser(
+    num_workers=n_cores_available,
+    precision=precision,
+    max_cells=adata.n_obs,
+    doplot=False,
+    dtype=dtype,
+)
+_, idxs, genes, expr_pred = denoiser(model, adata)
+print(f"Predicted expression dimensions: {expr_pred.shape}", flush=True)
+
+print("\n>>> Applying denoising...", flush=True)
+adata.X = adata.X.tolil()
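+# LIL format supports efficient per-element assignment, which the update loop
+# below relies on; assigning into the CSR matrix directly would be slow.
+# idxs can be None when the denoiser covered all cells, so fall back to all rows.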
+idxs = idxs if idxs is not None else range(adata.shape[0])
+gene_idx = adata.var.index.get_indexer(genes)
+for i, idx in enumerate(idxs):
+    adata.X[idx, gene_idx] = expr_pred[i]
+adata.X = adata.X.tocsr()
+print(adata)
+
+print("\n>>> Storing output...", flush=True)
+output = ad.AnnData(
+    layers={
+        # Subset back to the genes present in the input dataset
+        "denoised": adata.X[:, adata.var.index.get_indexer(input.var_names)],
+    },
+    obs=input.obs[[]],
+    var=input.var[[]],
+    uns={
+        "dataset_id": input.uns["dataset_id"],
+        "method_id": meta["name"],
+    },
+)
+print(output)
+
+print("\n>>> Writing output AnnData to file...", flush=True)
+output.write_h5ad(par["output"], compression="gzip")
+
+print("\n>>> Done!", flush=True)
diff --git a/src/workflows/run_benchmark/config.vsh.yaml b/src/workflows/run_benchmark/config.vsh.yaml
index 083dd30..6850e1d 100644
--- a/src/workflows/run_benchmark/config.vsh.yaml
+++ b/src/workflows/run_benchmark/config.vsh.yaml
@@ -75,6 +75,7 @@ dependencies:
   - name: methods/dca
   - name: methods/knn_smoothing
   - name: methods/magic
+  - name: methods/scprint
   - name: metrics/mse
   - name: metrics/poisson
 runners:
diff --git a/src/workflows/run_benchmark/main.nf b/src/workflows/run_benchmark/main.nf
index 97155fb..725ac72 100644
--- a/src/workflows/run_benchmark/main.nf
+++ b/src/workflows/run_benchmark/main.nf
@@ -20,7 +20,8 @@ workflow run_wf {
     alra,
     dca,
     knn_smoothing,
-    magic
+    magic,
+    scprint
   ]
 
   // construct list of metrics
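
To exercise the new component on its own, a minimal sketch (it assumes the
scprint Docker image has been built with
'scripts/project/build_all_docker_containers.sh' and that the task_denoising
test resources are available locally):

  viash run src/methods/scprint/config.vsh.yaml -- \
    --input_train resources_test/task_denoising/cxg_immune_cell_atlas/train.h5ad \
    --output output.h5ad \
    --model_name small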