generated from openproblems-bio/task_template
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add scPRINT component files * Load and preprocess data for scPRINT * Try running model... * Adjust scPRINT installation * Embed and save scPRINT output * Detect available cores * Adjust arguments if GPU available * Add model argument to scPRINT * Add scPRINT to benchmark workflow * Make scPRINT inherit from base method Model is too large for CI tests * style code * Apply suggestions from code review Co-authored-by: Robrecht Cannoodt <[email protected]> * Remove test workflow file * Fix test data path --------- Co-authored-by: Robrecht Cannoodt <[email protected]>
- Loading branch information
Showing
4 changed files
with
186 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
__merge__: /src/api/base_method.yaml | ||
|
||
name: scprint | ||
label: scPRINT | ||
summary: scPRINT is a large transformer model built for the inference of gene networks | ||
description: | | ||
scPRINT is a large transformer model built for the inference of gene networks | ||
(connections between genes explaining the cell's expression profile) from | ||
scRNAseq data. | ||
It uses novel encoding and decoding of the cell expression profile and new | ||
pre-training methodologies to learn a cell model. | ||
scPRINT can be used to perform the following analyses: | ||
- expression denoising: increase the resolution of your scRNAseq data | ||
- cell embedding: generate a low-dimensional representation of your dataset | ||
- label prediction: predict the cell type, disease, sequencer, sex, and | ||
ethnicity of your cells | ||
- gene network inference: generate a gene network from any cell or cell | ||
cluster in your scRNAseq dataset | ||
references: | ||
doi: | ||
- 10.1101/2024.07.29.605556 | ||
|
||
links: | ||
documentation: https://cantinilab.github.io/scPRINT/ | ||
repository: https://github.com/cantinilab/scPRINT | ||
|
||
info: | ||
preferred_normalization: counts | ||
method_types: [embedding] | ||
variants: | ||
scprint_large: | ||
model_name: "large" | ||
scprint_medium: | ||
model_name: "medium" | ||
scprint_small: | ||
model_name: "small" | ||
|
||
arguments: | ||
- name: "--model_name" | ||
type: "string" | ||
description: Which model to use. Not used if --model is provided. | ||
choices: ["large", "medium", "small"] | ||
default: "large" | ||
- name: --model | ||
type: file | ||
description: Path to the scPRINT model. | ||
required: false | ||
|
||
resources: | ||
- type: python_script | ||
path: script.py | ||
- path: /src/utils/read_anndata_partial.py | ||
|
||
engines: | ||
- type: docker | ||
image: openproblems/base_pytorch_nvidia:1.0.0 | ||
setup: | ||
- type: python | ||
pip: | ||
- huggingface_hub | ||
- scprint | ||
- type: docker | ||
run: lamin init --storage ./main --name main --schema bionty | ||
- type: python | ||
script: import bionty as bt; bt.core.sync_all_sources_to_latest() | ||
- type: docker | ||
run: lamin load anonymous/main | ||
- type: python | ||
script: from scdataloader.utils import populate_my_ontology; populate_my_ontology() | ||
|
||
runners: | ||
- type: executable | ||
- type: nextflow | ||
directives: | ||
label: [midtime, midmem, midcpu, gpu] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import anndata as ad | ||
from scdataloader import Preprocessor | ||
import sys | ||
from huggingface_hub import hf_hub_download | ||
from scprint.tasks import Embedder | ||
from scprint import scPrint | ||
import scprint | ||
import torch | ||
import os | ||
|
||
## VIASH START | ||
par = { | ||
"input": "resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad", | ||
"output": "output.h5ad", | ||
"model_name": "large", | ||
"model": None, | ||
} | ||
meta = {"name": "scprint"} | ||
## VIASH END | ||
|
||
sys.path.append(meta["resources_dir"]) | ||
from read_anndata_partial import read_anndata | ||
|
||
print(f"====== scPRINT version {scprint.__version__} ======", flush=True) | ||
|
||
print("\n>>> Reading input data...", flush=True) | ||
input = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns") | ||
if input.uns["dataset_organism"] == "homo_sapiens": | ||
input.obs["organism_ontology_term_id"] = "NCBITaxon:9606" | ||
elif input.uns["dataset_organism"] == "mus_musculus": | ||
input.obs["organism_ontology_term_id"] = "NCBITaxon:10090" | ||
else: | ||
raise ValueError( | ||
f"scPRINT requires human or mouse data, not '{input.uns['dataset_organism']}'" | ||
) | ||
adata = input.copy() | ||
|
||
print("\n>>> Preprocessing data...", flush=True) | ||
preprocessor = Preprocessor( | ||
# Lower this threshold for test datasets | ||
min_valid_genes_id=1000 if input.n_vars < 2000 else 10000, | ||
# Turn off cell filtering to return results for all cells | ||
filter_cell_by_counts=False, | ||
min_nnz_genes=False, | ||
do_postp=False, | ||
# Skip ontology checks | ||
skip_validate=True, | ||
) | ||
adata = preprocessor(adata) | ||
|
||
model_checkpoint_file = par["model"] | ||
if model_checkpoint_file is None: | ||
print(f"\n>>> Downloading '{par['model_name']}' model...", flush=True) | ||
model_checkpoint_file = hf_hub_download( | ||
repo_id="jkobject/scPRINT", filename=f"{par['model_name']}.ckpt" | ||
) | ||
print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True) | ||
model = scPrint.load_from_checkpoint( | ||
model_checkpoint_file, | ||
transformer="normal", # Don't use this for GPUs with flashattention | ||
precpt_gene_emb=None, | ||
) | ||
|
||
print("\n>>> Embedding data...", flush=True) | ||
if torch.cuda.is_available(): | ||
print("CUDA is available, using GPU", flush=True) | ||
precision = "16" | ||
dtype = torch.float16 | ||
else: | ||
print("CUDA is not available, using CPU", flush=True) | ||
precision = "32" | ||
dtype = torch.float32 | ||
n_cores_available = len(os.sched_getaffinity(0)) | ||
print(f"Using {n_cores_available} worker cores") | ||
embedder = Embedder( | ||
how="random expr", | ||
max_len=4000, | ||
add_zero_genes=0, | ||
num_workers=n_cores_available, | ||
doclass=False, | ||
doplot=False, | ||
precision=precision, | ||
dtype=dtype, | ||
) | ||
embedded, _ = embedder(model, adata, cache=False) | ||
|
||
print("\n>>> Storing output...", flush=True) | ||
output = ad.AnnData( | ||
obs=input.obs[[]], | ||
var=input.var[[]], | ||
obsm={ | ||
"X_emb": embedded.obsm["scprint"], | ||
}, | ||
uns={ | ||
"dataset_id": input.uns["dataset_id"], | ||
"normalization_id": input.uns["normalization_id"], | ||
"method_id": meta["name"], | ||
}, | ||
) | ||
print(output) | ||
|
||
print("\n>>> Writing output AnnData to file...", flush=True) | ||
output.write_h5ad(par["output"], compression="gzip") | ||
|
||
print("\n>>> Done!", flush=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters