generated from openproblems-bio/task_template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Update common submodule * Use checkItemAllowed() for benchmark method check * Replace cxg_mouse_pancreas_atlas with cxg_immune_cell_atlas * Update README * Update CHANGELOG * Add a base method API schema * Update CHANGELOG * Add config check to base method schema * Add dataset_organism to training dataset files * Add scPRINT component * Adapt scPRINT for denoising task * Add scPRINT to benchmark workflow * Update CHANGELOG
- Loading branch information
Showing
6 changed files
with
208 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# Viash component configuration for the scPRINT method.
# Shared method settings are merged in from the base method API file.
__merge__: /src/api/base_method.yaml

name: scprint
label: scPRINT
summary: scPRINT is a large transformer model built for the inference of gene networks
description: |
  scPRINT is a large transformer model built for the inference of gene networks
  (connections between genes explaining the cell's expression profile) from
  scRNAseq data.
  It uses novel encoding and decoding of the cell expression profile and new
  pre-training methodologies to learn a cell model.
  scPRINT can be used to perform the following analyses:
  - expression denoising: increase the resolution of your scRNAseq data
  - cell embedding: generate a low-dimensional representation of your dataset
  - label prediction: predict the cell type, disease, sequencer, sex, and
    ethnicity of your cells
  - gene network inference: generate a gene network from any cell or cell
    cluster in your scRNAseq dataset
references:
  doi:
    - 10.1101/2024.07.29.605556

links:
  documentation: https://cantinilab.github.io/scPRINT/
  repository: https://github.com/cantinilab/scPRINT

info:
  # The method consumes raw counts
  preferred_normalization: counts
  # One benchmark variant per pretrained model size
  variants:
    scprint_large:
      model_name: large
    scprint_medium:
      model_name: medium
    scprint_small:
      model_name: small

arguments:
  # Name of the pretrained checkpoint to download when --model is not given
  - name: "--model_name"
    type: string
    description: Which model to use. Not used if --model is provided.
    choices:
      - large
      - medium
      - small
    default: large
  # Optional local checkpoint; takes precedence over --model_name
  - name: "--model"
    type: file
    description: Path to the scPRINT model.
    required: false

resources:
  - type: python_script
    path: script.py

engines:
  - type: docker
    image: openproblems/base_pytorch_nvidia:1.0.0
    setup:
      # Python packages imported by script.py
      - type: python
        pip:
          - huggingface_hub
          - scprint
      # Set up a lamin instance and populate the ontologies at image build time
      - type: docker
        run: lamin init --storage ./main --name main --schema bionty
      - type: python
        script: import bionty as bt; bt.core.sync_all_sources_to_latest()
      - type: docker
        run: lamin load anonymous/main
      - type: python
        script: from scdataloader.utils import populate_my_ontology; populate_my_ontology()

runners:
  - type: executable
  - type: nextflow
    directives:
      label:
        - midtime
        - midmem
        - midcpu
        - gpu
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import os | ||
|
||
import anndata as ad | ||
import scprint | ||
import torch | ||
from huggingface_hub import hf_hub_download | ||
from scdataloader import Preprocessor | ||
from scprint import scPrint | ||
from scprint.tasks import Denoiser | ||
import numpy as np | ||
|
||
## VIASH START
# Placeholder parameters used when running this script directly during
# development (the block between the VIASH START/END markers is the
# development stand-in for the values viash injects).
par = {
    # Training dataset of the denoising task; counts are read from
    # layers["counts"] further below
    "input_train": "resources_test/task_denoising/cxg_immune_cell_atlas/train.h5ad",
    "output": "output.h5ad",
    # Pretrained checkpoint to download; only used when "model" is None
    "model_name": "large",
    # Optional path to a local scPRINT checkpoint file
    "model": None,
}
meta = {"name": "scprint"}
## VIASH END
|
||
# Log the installed scPRINT version first so runs are easy to identify
print(f"====== scPRINT version {scprint.__version__} ======", flush=True)

print("\n>>> Reading input data...", flush=True)
input = ad.read_h5ad(par["input_train"])
print(input)

print("\n>>> Preprocessing data...", flush=True)
# Start from a minimal AnnData holding only the counts, with the same cell
# and gene names as the input dataset
adata = ad.AnnData(X=input.layers["counts"])
adata.obs_names = input.obs_names
adata.var_names = input.var_names

# Map the dataset organism onto the taxon ID stored for every cell; only
# human and mouse are accepted
organism = input.uns["dataset_organism"]
if organism == "homo_sapiens":
    taxon_id = "NCBITaxon:9606"
elif organism == "mus_musculus":
    taxon_id = "NCBITaxon:10090"
else:
    raise ValueError(
        f"scPRINT requires human or mouse data, not '{input.uns['dataset_organism']}'"
    )
adata.obs["organism_ontology_term_id"] = taxon_id
|
||
# Datasets with fewer than 2000 genes (i.e. test datasets) get a lower
# threshold for the number of valid gene IDs
min_valid_genes = 1000 if input.n_vars < 2000 else 10000

preprocessor_settings = {
    "min_valid_genes_id": min_valid_genes,
    # Turn off cell filtering to return results for all cells
    "filter_cell_by_counts": False,
    "min_nnz_genes": False,
    "do_postp": False,
    # Skip ontology checks
    "skip_validate": True,
}
preprocessor = Preprocessor(**preprocessor_settings)
adata = preprocessor(adata)
print(adata)
|
||
# Use the checkpoint passed via --model when there is one; otherwise download
# the checkpoint named by --model_name from the Hugging Face Hub
if par["model"] is not None:
    model_checkpoint_file = par["model"]
else:
    print(f"\n>>> Downloading '{par['model_name']}' model...", flush=True)
    model_checkpoint_file = hf_hub_download(
        filename=f"{par['model_name']}.ckpt",
        repo_id="jkobject/scPRINT",
    )
print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True)

checkpoint_overrides = {
    "transformer": "normal",  # Don't use this for GPUs with flashattention
    "precpt_gene_emb": None,
}
model = scPrint.load_from_checkpoint(model_checkpoint_file, **checkpoint_overrides)
|
||
print("\n>>> Denoising data...", flush=True)
# Pick the numeric precision to match the available hardware: half precision
# on GPU, full precision on CPU
cuda_available = torch.cuda.is_available()
if cuda_available:
    print("CUDA is available, using GPU", flush=True)
else:
    print("CUDA is not available, using CPU", flush=True)
precision = "16-mixed" if cuda_available else "32"
dtype = torch.float16 if cuda_available else torch.float32

# One worker per CPU core this process is allowed to run on
n_cores_available = len(os.sched_getaffinity(0))
print(f"Using {n_cores_available} worker cores")

denoiser_settings = {
    "num_workers": n_cores_available,
    "precision": precision,
    # Allow as many cells as the dataset contains
    "max_cells": adata.n_obs,
    "doplot": False,
    "dtype": dtype,
}
denoiser = Denoiser(**denoiser_settings)
# The first returned value is not used by this script
_, idxs, genes, expr_pred = denoiser(model, adata)
print(f"Predicted expression dimensions: {expr_pred.shape}")
|
||
print("\n>>> Applying denoising...", flush=True)
# Row-wise assignment is done on a LIL matrix; it is converted back to CSR
# once all rows are written
adata.X = adata.X.tolil()
# Column positions of the predicted genes in adata. They are identical for
# every cell, so compute them once instead of on each loop iteration.
# NOTE(review): get_indexer returns -1 for names missing from the index;
# this assumes every entry of `genes` is present in adata.var — confirm.
gene_columns = adata.var.index.get_indexer(genes)
# Without explicit cell indices, the predictions are applied to all rows in order
idxs = idxs if idxs is not None else range(adata.shape[0])
for i, idx in enumerate(idxs):
    adata.X[idx, gene_columns] = expr_pred[i]
adata.X = adata.X.tocsr()
print(adata)
|
||
print("\n>>> Storing output...", flush=True)
# Select the denoised columns in the gene order of the input dataset
input_gene_columns = adata.var.index.get_indexer(input.var_names)
denoised = adata.X[:, input_gene_columns]

# Carry over only the obs/var indexes (no columns) from the input
output_uns = {
    "dataset_id": input.uns["dataset_id"],
    "method_id": meta["name"],
}
output = ad.AnnData(
    layers={"denoised": denoised},
    obs=input.obs[[]],
    var=input.var[[]],
    uns=output_uns,
)
print(output)

print("\n>>> Writing output AnnData to file...", flush=True)
output.write_h5ad(par["output"], compression="gzip")

print("\n>>> Done!", flush=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,7 +20,8 @@ workflow run_wf { | |
alra, | ||
dca, | ||
knn_smoothing, | ||
magic | ||
magic, | ||
scprint | ||
] | ||
|
||
// construct list of metrics | ||
|