Add scPRINT method (#13)
* Add scPRINT component files

* Load and preprocess data for scPRINT

* Try running model...

* Adjust scPRINT installation

* Embed and save scPRINT output

* Detect available cores

* Adjust arguments if GPU available

* Add model argument to scPRINT

* Add scPRINT to benchmark workflow

* Make scPRINT inherit from base method

Model is too large for CI tests

* Style code

* Apply suggestions from code review

Co-authored-by: Robrecht Cannoodt <[email protected]>

* Remove test workflow file

* Fix test data path

---------

Co-authored-by: Robrecht Cannoodt <[email protected]>
lazappi and rcannood authored Nov 26, 2024
1 parent 6dc0f1d commit 52ccedb
Showing 4 changed files with 186 additions and 0 deletions.
79 changes: 79 additions & 0 deletions src/methods/scprint/config.vsh.yaml
@@ -0,0 +1,79 @@
__merge__: /src/api/base_method.yaml

name: scprint
label: scPRINT
summary: scPRINT is a large transformer model built for the inference of gene networks
description: |
scPRINT is a large transformer model built for the inference of gene networks
(connections between genes explaining the cell's expression profile) from
scRNAseq data.
It uses novel encoding and decoding of the cell expression profile and new
pre-training methodologies to learn a cell model.
scPRINT can be used to perform the following analyses:
- expression denoising: increase the resolution of your scRNAseq data
- cell embedding: generate a low-dimensional representation of your dataset
- label prediction: predict the cell type, disease, sequencer, sex, and
ethnicity of your cells
- gene network inference: generate a gene network from any cell or cell
cluster in your scRNAseq dataset
references:
doi:
- 10.1101/2024.07.29.605556

links:
documentation: https://cantinilab.github.io/scPRINT/
repository: https://github.com/cantinilab/scPRINT

info:
preferred_normalization: counts
method_types: [embedding]
variants:
scprint_large:
model_name: "large"
scprint_medium:
model_name: "medium"
scprint_small:
model_name: "small"

arguments:
- name: "--model_name"
type: "string"
description: Which model to use. Not used if --model is provided.
choices: ["large", "medium", "small"]
default: "large"
- name: --model
type: file
description: Path to the scPRINT model.
required: false

resources:
- type: python_script
path: script.py
- path: /src/utils/read_anndata_partial.py

engines:
- type: docker
image: openproblems/base_pytorch_nvidia:1.0.0
setup:
- type: python
pip:
- huggingface_hub
- scprint
- type: docker
run: lamin init --storage ./main --name main --schema bionty
- type: python
script: import bionty as bt; bt.core.sync_all_sources_to_latest()
- type: docker
run: lamin load anonymous/main
- type: python
script: from scdataloader.utils import populate_my_ontology; populate_my_ontology()

runners:
- type: executable
- type: nextflow
directives:
label: [midtime, midmem, midcpu, gpu]
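Note: the variants above differ only in which checkpoint is fetched when --model is not supplied. A minimal sketch of that resolution, assuming the jkobject/scPRINT Hub repository hosts one .ckpt file per size (the same fallback used by script.py below):

# Sketch: resolve each variant's model_name to a Hugging Face checkpoint.
# Mirrors the fallback in script.py when --model is unset.
from huggingface_hub import hf_hub_download

for model_name in ["large", "medium", "small"]:
    ckpt = hf_hub_download(repo_id="jkobject/scPRINT", filename=f"{model_name}.ckpt")
    print(f"{model_name} -> {ckpt}")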
105 changes: 105 additions & 0 deletions src/methods/scprint/script.py
@@ -0,0 +1,105 @@
import anndata as ad
from scdataloader import Preprocessor
import sys
from huggingface_hub import hf_hub_download
from scprint.tasks import Embedder
from scprint import scPrint
import scprint
import torch
import os

## VIASH START
par = {
"input": "resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad",
"output": "output.h5ad",
"model_name": "large",
"model": None,
}
meta = {"name": "scprint"}
## VIASH END

sys.path.append(meta["resources_dir"])
from read_anndata_partial import read_anndata

print(f"====== scPRINT version {scprint.__version__} ======", flush=True)

print("\n>>> Reading input data...", flush=True)
input = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns")
if input.uns["dataset_organism"] == "homo_sapiens":
input.obs["organism_ontology_term_id"] = "NCBITaxon:9606"
elif input.uns["dataset_organism"] == "mus_musculus":
input.obs["organism_ontology_term_id"] = "NCBITaxon:10090"
else:
raise ValueError(
f"scPRINT requires human or mouse data, not '{input.uns['dataset_organism']}'"
)
adata = input.copy()

print("\n>>> Preprocessing data...", flush=True)
preprocessor = Preprocessor(
# Lower this threshold for test datasets
min_valid_genes_id=1000 if input.n_vars < 2000 else 10000,
# Turn off cell filtering to return results for all cells
filter_cell_by_counts=False,
min_nnz_genes=False,
do_postp=False,
# Skip ontology checks
skip_validate=True,
)
adata = preprocessor(adata)

model_checkpoint_file = par["model"]
if model_checkpoint_file is None:
print(f"\n>>> Downloading '{par['model_name']}' model...", flush=True)
model_checkpoint_file = hf_hub_download(
repo_id="jkobject/scPRINT", filename=f"{par['model_name']}.ckpt"
)
print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True)
model = scPrint.load_from_checkpoint(
model_checkpoint_file,
transformer="normal", # Don't use this for GPUs with flashattention
precpt_gene_emb=None,
)

print("\n>>> Embedding data...", flush=True)
if torch.cuda.is_available():
print("CUDA is available, using GPU", flush=True)
precision = "16"
dtype = torch.float16
else:
print("CUDA is not available, using CPU", flush=True)
precision = "32"
dtype = torch.float32
n_cores_available = len(os.sched_getaffinity(0))
print(f"Using {n_cores_available} worker cores")
embedder = Embedder(
how="random expr",
max_len=4000,
add_zero_genes=0,
num_workers=n_cores_available,
doclass=False,
doplot=False,
precision=precision,
dtype=dtype,
)
embedded, _ = embedder(model, adata, cache=False)

print("\n>>> Storing output...", flush=True)
output = ad.AnnData(
obs=input.obs[[]],
var=input.var[[]],
obsm={
"X_emb": embedded.obsm["scprint"],
},
uns={
"dataset_id": input.uns["dataset_id"],
"normalization_id": input.uns["normalization_id"],
"method_id": meta["name"],
},
)
print(output)

print("\n>>> Writing output AnnData to file...", flush=True)
output.write_h5ad(par["output"], compression="gzip")

print("\n>>> Done!", flush=True)
1 change: 1 addition & 0 deletions src/workflows/run_benchmark/config.vsh.yaml
@@ -96,6 +96,7 @@ dependencies:
- name: methods/scanvi
- name: methods/scgpt
- name: methods/scimilarity
- name: methods/scprint
- name: methods/scvi
- name: methods/uce
# metrics
1 change: 1 addition & 0 deletions src/workflows/run_benchmark/main.nf
@@ -33,6 +33,7 @@ methods = [
scimilarity.run(
args: [model: file("s3://openproblems-work/cache/scimilarity-model_v1.1.tar.gz")]
),
scprint,
scvi,
uce.run(
args: [model: file("s3://openproblems-work/cache/uce-model-v5.zip")]
