Skip to content

Commit

Permalink
Add scPRINT method (#25)
Browse files Browse the repository at this point in the history
* Update common submodule

* Use checkItemAllowed() for benchmark method check

* Replace cxg_mouse_pancreas_atlas with cxg_immune_cell_atlas

* Update README

* Update CHANGELOG

* Add a base method API schema

* Update CHANGELOG

* Add config check to base method schema

* Add dataset_organism to training dataset files

* Add scPRINT component

* Adapt scPRINT for denoising task

* Add scPRINT to benchmark workflow

* Update CHANGELOG
  • Loading branch information
lazappi authored Dec 19, 2024
1 parent 9c77313 commit 252731b
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 6 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

* Update `process_dataset` component to subsample large datasets (PR #14).

* Add the scPRINT method (PR #25)

## MAJOR CHANGES

* Revamp `scripts` directory (PR #13).
Expand Down
16 changes: 11 additions & 5 deletions scripts/run_benchmark/run_test_local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,19 @@ echo " Make sure to run 'scripts/project/build_all_docker_containers.sh'!"
RUN_ID="testrun_$(date +%Y-%m-%d_%H-%M-%S)"
publish_dir="temp/results/${RUN_ID}"

# write the parameters to file
cat > /tmp/params.yaml << HERE
input_states: resources_test/task_denoising/**/state.yaml
rename_keys: 'input_train:train;input_test:test'
output_state: "state.yaml"
publish_dir: "$publish_dir"
settings: '{"methods_exclude": ["scprint"]}'
HERE

nextflow run . \
-main-script target/nextflow/workflows/run_benchmark/main.nf \
-profile docker \
-resume \
-entry auto \
-c common/nextflow_helpers/labels_ci.config \
--id cxg_immune_cell_atlas \
--input_train resources_test/task_denoising/cxg_immune_cell_atlas/train.h5ad \
--input_test resources_test/task_denoising/cxg_immune_cell_atlas/test.h5ad \
--output_state state.yaml \
--publish_dir "$publish_dir"
-params-file /tmp/params.yaml
77 changes: 77 additions & 0 deletions src/methods/scprint/config.vsh.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
__merge__: /src/api/base_method.yaml

name: scprint
label: scPRINT
summary: scPRINT is a large transformer model built for the inference of gene networks
description: |
scPRINT is a large transformer model built for the inference of gene networks
(connections between genes explaining the cell's expression profile) from
scRNAseq data.
It uses novel encoding and decoding of the cell expression profile and new
pre-training methodologies to learn a cell model.
scPRINT can be used to perform the following analyses:
- expression denoising: increase the resolution of your scRNAseq data
- cell embedding: generate a low-dimensional representation of your dataset
- label prediction: predict the cell type, disease, sequencer, sex, and
ethnicity of your cells
- gene network inference: generate a gene network from any cell or cell
cluster in your scRNAseq dataset
references:
doi:
- 10.1101/2024.07.29.605556

links:
documentation: https://cantinilab.github.io/scPRINT/
repository: https://github.com/cantinilab/scPRINT

info:
preferred_normalization: counts
variants:
scprint_large:
model_name: "large"
scprint_medium:
model_name: "medium"
scprint_small:
model_name: "small"

arguments:
- name: "--model_name"
type: "string"
description: Which model to use. Not used if --model is provided.
choices: ["large", "medium", "small"]
default: "large"
- name: --model
type: file
description: Path to the scPRINT model.
required: false

resources:
- type: python_script
path: script.py

engines:
- type: docker
image: openproblems/base_pytorch_nvidia:1.0.0
setup:
- type: python
pip:
- huggingface_hub
- scprint
- type: docker
run: lamin init --storage ./main --name main --schema bionty
- type: python
script: import bionty as bt; bt.core.sync_all_sources_to_latest()
- type: docker
run: lamin load anonymous/main
- type: python
script: from scdataloader.utils import populate_my_ontology; populate_my_ontology()

runners:
- type: executable
- type: nextflow
directives:
label: [midtime, midmem, midcpu, gpu]
115 changes: 115 additions & 0 deletions src/methods/scprint/script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import os

import anndata as ad
import scprint
import torch
from huggingface_hub import hf_hub_download
from scdataloader import Preprocessor
from scprint import scPrint
from scprint.tasks import Denoiser
import numpy as np

## VIASH START
par = {
"input_train": "resources_test/task_batch_integration/cxg_immune_cell_atlas/train.h5ad",
"output": "output.h5ad",
"model_name": "large",
"model": None,
}
meta = {"name": "scprint"}
## VIASH END

print(f"====== scPRINT version {scprint.__version__} ======", flush=True)

print("\n>>> Reading input data...", flush=True)
input = ad.read_h5ad(par["input_train"])
print(input)

print("\n>>> Preprocessing data...", flush=True)
adata = ad.AnnData(
X=input.layers["counts"]
)
adata.obs_names = input.obs_names
adata.var_names = input.var_names
if input.uns["dataset_organism"] == "homo_sapiens":
adata.obs["organism_ontology_term_id"] = "NCBITaxon:9606"
elif input.uns["dataset_organism"] == "mus_musculus":
adata.obs["organism_ontology_term_id"] = "NCBITaxon:10090"
else:
raise ValueError(
f"scPRINT requires human or mouse data, not '{input.uns['dataset_organism']}'"
)

preprocessor = Preprocessor(
# Lower this threshold for test datasets
min_valid_genes_id=1000 if input.n_vars < 2000 else 10000,
# Turn off cell filtering to return results for all cells
filter_cell_by_counts=False,
min_nnz_genes=False,
do_postp=False,
# Skip ontology checks
skip_validate=True,
)
adata = preprocessor(adata)
print(adata)

model_checkpoint_file = par["model"]
if model_checkpoint_file is None:
print(f"\n>>> Downloading '{par['model_name']}' model...", flush=True)
model_checkpoint_file = hf_hub_download(
repo_id="jkobject/scPRINT", filename=f"{par['model_name']}.ckpt"
)
print(f"Model checkpoint file: '{model_checkpoint_file}'", flush=True)
model = scPrint.load_from_checkpoint(
model_checkpoint_file,
transformer="normal", # Don't use this for GPUs with flashattention
precpt_gene_emb=None,
)

print("\n>>> Denoising data...", flush=True)
if torch.cuda.is_available():
print("CUDA is available, using GPU", flush=True)
precision = "16-mixed"
dtype = torch.float16
else:
print("CUDA is not available, using CPU", flush=True)
precision = "32"
dtype = torch.float32
n_cores_available = len(os.sched_getaffinity(0))
print(f"Using {n_cores_available} worker cores")
denoiser = Denoiser(
num_workers=n_cores_available,
precision=precision,
max_cells=adata.n_obs,
doplot=False,
dtype=dtype,
)
_, idxs, genes, expr_pred = denoiser(model, adata)
print(f"Predicted expression dimensions: {expr_pred.shape}")

print("\n>>> Applying denoising...", flush=True)
adata.X = adata.X.tolil()
idxs = idxs if idxs is not None else range(adata.shape[0])
for i, idx in enumerate(idxs):
adata.X[idx, adata.var.index.get_indexer(genes)] = expr_pred[i]
adata.X = adata.X.tocsr()
print(adata)

print("\n>>> Storing output...", flush=True)
output = ad.AnnData(
layers={
"denoised": adata.X[:, adata.var.index.get_indexer(input.var_names)],
},
obs=input.obs[[]],
var=input.var[[]],
uns={
"dataset_id": input.uns["dataset_id"],
"method_id": meta["name"],
},
)
print(output)

print("\n>>> Writing output AnnData to file...", flush=True)
output.write_h5ad(par["output"], compression="gzip")

print("\n>>> Done!", flush=True)
1 change: 1 addition & 0 deletions src/workflows/run_benchmark/config.vsh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ dependencies:
- name: methods/dca
- name: methods/knn_smoothing
- name: methods/magic
- name: methods/scprint
- name: metrics/mse
- name: metrics/poisson
runners:
Expand Down
3 changes: 2 additions & 1 deletion src/workflows/run_benchmark/main.nf
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ workflow run_wf {
alra,
dca,
knn_smoothing,
magic
magic,
scprint
]

// construct list of metrics
Expand Down

0 comments on commit 252731b

Please sign in to comment.