From 52ccedbd0221be2af27c3e9dd584a7ae63e8a0a1 Mon Sep 17 00:00:00 2001
From: Luke Zappia
Date: Tue, 26 Nov 2024 14:24:14 +0100
Subject: [PATCH] Add scPRINT method (#13)

* Add scPRINT component files
* Load and preprocess data for scPRINT
* Try running model...
* Adjust scPRINT installation
* Embed and save scPRINT output
* Detect available cores
* Adjust arguments if GPU available
* Add model argument to scPRINT
* Add scPRINT to benchmark workflow
* Make scPRINT inherit from base method

Model is too large for CI tests

* style code
* Apply suggestions from code review

Co-authored-by: Robrecht Cannoodt

* Remove test workflow file
* Fix test data path

---------

Co-authored-by: Robrecht Cannoodt
---
 src/methods/scprint/config.vsh.yaml         |  79 +++++++++++++++
 src/methods/scprint/script.py               | 105 ++++++++++++++++++++
 src/workflows/run_benchmark/config.vsh.yaml |   1 +
 src/workflows/run_benchmark/main.nf         |   1 +
 4 files changed, 186 insertions(+)
 create mode 100644 src/methods/scprint/config.vsh.yaml
 create mode 100644 src/methods/scprint/script.py

diff --git a/src/methods/scprint/config.vsh.yaml b/src/methods/scprint/config.vsh.yaml
new file mode 100644
index 0000000..12100df
--- /dev/null
+++ b/src/methods/scprint/config.vsh.yaml
@@ -0,0 +1,79 @@
+__merge__: /src/api/base_method.yaml
+
+name: scprint
+label: scPRINT
+summary: scPRINT is a large transformer model built for the inference of gene networks
+description: |
+  scPRINT is a large transformer model built for the inference of gene networks
+  (connections between genes explaining the cell's expression profile) from
+  scRNAseq data.
+
+  It uses novel encoding and decoding of the cell expression profile and new
+  pre-training methodologies to learn a cell model.
+
+  scPRINT can be used to perform the following analyses:
+
+  - expression denoising: increase the resolution of your scRNAseq data
+  - cell embedding: generate a low-dimensional representation of your dataset
+  - label prediction: predict the cell type, disease, sequencer, sex, and
+    ethnicity of your cells
+  - gene network inference: generate a gene network from any cell or cell
+    cluster in your scRNAseq dataset
+
+references:
+  doi:
+    - 10.1101/2024.07.29.605556
+
+links:
+  documentation: https://cantinilab.github.io/scPRINT/
+  repository: https://github.com/cantinilab/scPRINT
+
+info:
+  preferred_normalization: counts
+  method_types: [embedding]
+  variants:
+    scprint_large:
+      model_name: "large"
+    scprint_medium:
+      model_name: "medium"
+    scprint_small:
+      model_name: "small"
+
+arguments:
+  - name: "--model_name"
+    type: "string"
+    description: Which model to use. Not used if --model is provided.
+    choices: ["large", "medium", "small"]
+    default: "large"
+  - name: --model
+    type: file
+    description: Path to the scPRINT model.
+    required: false
+
+resources:
+  - type: python_script
+    path: script.py
+  - path: /src/utils/read_anndata_partial.py
+
+engines:
+  - type: docker
+    image: openproblems/base_pytorch_nvidia:1.0.0
+    setup:
+      - type: python
+        pip:
+          - huggingface_hub
+          - scprint
+      - type: docker
+        run: lamin init --storage ./main --name main --schema bionty
+      - type: python
+        script: import bionty as bt; bt.core.sync_all_sources_to_latest()
+      - type: docker
+        run: lamin load anonymous/main
+      - type: python
+        script: from scdataloader.utils import populate_my_ontology; populate_my_ontology()
+
+runners:
+  - type: executable
+  - type: nextflow
+    directives:
+      label: [midtime, midmem, midcpu, gpu]
diff --git a/src/methods/scprint/script.py b/src/methods/scprint/script.py
new file mode 100644
index 0000000..6c1d6b9
--- /dev/null
+++ b/src/methods/scprint/script.py
@@ -0,0 +1,105 @@
+import anndata as ad
+from scdataloader import Preprocessor
+import sys
+from huggingface_hub import hf_hub_download
+from scprint.tasks import Embedder
+from scprint import scPrint
+import scprint
+import torch
+import os
+
+## VIASH START
+par = {
+    "input": "resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad",
+    "output": "output.h5ad",
+    "model_name": "large",
+    "model": None,
+}
+meta = {"name": "scprint"}
+## VIASH END
+
+sys.path.append(meta["resources_dir"])
+from read_anndata_partial import read_anndata
+
+print(f"====== scPRINT version {scprint.__version__} ======", flush=True)
+
+print("\n>>> Reading input data...", flush=True)
+input = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns")
+if input.uns["dataset_organism"] == "homo_sapiens":
+    input.obs["organism_ontology_term_id"] = "NCBITaxon:9606"
+elif input.uns["dataset_organism"] == "mus_musculus":
+    input.obs["organism_ontology_term_id"] = "NCBITaxon:10090"
+else:
+    raise ValueError(
+        f"scPRINT requires human or mouse data, not '{input.uns['dataset_organism']}'"
+    )
+adata = input.copy()
+
+print("\n>>> Preprocessing data...", flush=True)
+preprocessor = Preprocessor(
+    # Lower this threshold for test datasets
+    min_valid_genes_id=1000 if input.n_vars < 2000 else 10000,
+    # Turn off cell filtering to return results for all cells
+    filter_cell_by_counts=False,
+    min_nnz_genes=False,
+    do_postp=False,
+    # Skip ontology checks
+    skip_validate=True,
+)
+adata = preprocessor(adata)
+
+model_checkpoint_file = par["model"]
+if model_checkpoint_file is None:
+    print(f"\n>>> Downloading '{par['model_name']}' model...", flush=True)
+    model_checkpoint_file = hf_hub_download(
+        repo_id="jkobject/scPRINT", filename=f"{par['model_name']}.ckpt"
+    )
+print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True)
+model = scPrint.load_from_checkpoint(
+    model_checkpoint_file,
+    transformer="normal",  # Don't use this for GPUs with flashattention
+    precpt_gene_emb=None,
+)
+
+print("\n>>> Embedding data...", flush=True)
+if torch.cuda.is_available():
+    print("CUDA is available, using GPU", flush=True)
+    precision = "16"
+    dtype = torch.float16
+else:
+    print("CUDA is not available, using CPU", flush=True)
+    precision = "32"
+    dtype = torch.float32
+n_cores_available = len(os.sched_getaffinity(0))
+print(f"Using {n_cores_available} worker cores")
+embedder = Embedder(
+    how="random expr",
+    max_len=4000,
+    add_zero_genes=0,
+    num_workers=n_cores_available,
+    doclass=False,
+    doplot=False,
+    precision=precision,
+    dtype=dtype,
+)
+embedded, _ = embedder(model, adata, cache=False)
+
+print("\n>>> Storing output...", flush=True)
+output = ad.AnnData(
+    obs=input.obs[[]],
+    var=input.var[[]],
+    obsm={
+        "X_emb": embedded.obsm["scprint"],
+    },
+    uns={
+        "dataset_id": input.uns["dataset_id"],
+        "normalization_id": input.uns["normalization_id"],
+        "method_id": meta["name"],
+    },
+)
+print(output)
+
+print("\n>>> Writing output AnnData to file...", flush=True)
+output.write_h5ad(par["output"], compression="gzip")
+
+print("\n>>> Done!", flush=True)
diff --git a/src/workflows/run_benchmark/config.vsh.yaml b/src/workflows/run_benchmark/config.vsh.yaml
index 1047ea8..75c9300 100644
--- a/src/workflows/run_benchmark/config.vsh.yaml
+++ b/src/workflows/run_benchmark/config.vsh.yaml
@@ -96,6 +96,7 @@ dependencies:
   - name: methods/scanvi
   - name: methods/scgpt
   - name: methods/scimilarity
+  - name: methods/scprint
   - name: methods/scvi
   - name: methods/uce
   # metrics
diff --git a/src/workflows/run_benchmark/main.nf b/src/workflows/run_benchmark/main.nf
index aaacb43..afcb968 100644
--- a/src/workflows/run_benchmark/main.nf
+++ b/src/workflows/run_benchmark/main.nf
@@ -33,6 +33,7 @@ methods = [
   scimilarity.run(
     args: [model: file("s3://openproblems-work/cache/scimilarity-model_v1.1.tar.gz")]
   ),
+  scprint,
   scvi,
   uce.run(
     args: [model: file("s3://openproblems-work/cache/uce-model-v5.zip")]