Skip to content

Commit

Permalink
Add argument to give model path to scGPT (#16)
Browse files Browse the repository at this point in the history
* Add model path argument to scGPT

* Use cached scGPT model in benchmark workflow

* Make scGPT inherit from base method

* Swap scGPT model argument names
  • Loading branch information
lazappi authored Dec 10, 2024
1 parent 52ccedb commit dd18949
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 17 deletions.
13 changes: 10 additions & 3 deletions src/methods/scgpt/config.vsh.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__merge__: ../../api/comp_method.yaml
__merge__: ../../api/base_method.yaml

name: scgpt
label: scGPT
Expand All @@ -24,11 +24,18 @@ info:
model: "scGPT_CP"

arguments:
- name: --model
- name: --model_name
type: string
description: String giving the scGPT model to use
description: String giving the name of the scGPT model to use
choices: ["scGPT_human", "scGPT_CP"]
default: "scGPT_human"
- name: --model
type: file
        description: |
          Path to the directory containing the scGPT model specified by model_name,
          or a .zip/.tar.gz archive to extract. If not given, the model will be
          downloaded.
required: false
- name: --n_hvg
type: integer
default: 3000
Expand Down
62 changes: 49 additions & 13 deletions src/methods/scgpt/script.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os
import sys
import tarfile
import tempfile
import zipfile

import anndata as ad
import gdown
Expand All @@ -12,6 +15,7 @@
par = {
"input": "resources_test/.../input.h5ad",
"output": "output.h5ad",
"model_name": "scGPT_human",
"model": "scGPT_human",
"n_hvg": 3000,
}
Expand Down Expand Up @@ -43,23 +47,54 @@

print(adata, flush=True)

print(f"\n>>> Downloading '{par['model']}' model...", flush=True)
model_drive_ids = {
"scGPT_human": "1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y",
"scGPT_CP": "1_GROJTzXiAV8HB4imruOTk6PEGuNOcgB",
}
drive_path = f"https://drive.google.com/drive/folders/{model_drive_ids[par['model']]}"
model_dir = tempfile.TemporaryDirectory()
print(f"Downloading from '{drive_path}'", flush=True)
gdown.download_folder(drive_path, output=model_dir.name, quiet=True)
print(f"Model directory: '{model_dir.name}'", flush=True)
if par["model"] is None:
print(f"\n>>> Downloading '{par['model_name']}' model...", flush=True)
model_drive_ids = {
"scGPT_human": "1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y",
"scGPT_CP": "1_GROJTzXiAV8HB4imruOTk6PEGuNOcgB",
}
drive_path = (
f"https://drive.google.com/drive/folders/{model_drive_ids[par['model_name']]}"
)
model_temp = tempfile.TemporaryDirectory()
model_dir = model_temp.name
print(f"Downloading from '{drive_path}'", flush=True)
gdown.download_folder(drive_path, output=model_dir, quiet=True)
else:
if os.path.isdir(par["model"]):
print(f"\n>>> Using model directory...", flush=True)
model_temp = None
model_dir = par["model"]
else:
model_temp = tempfile.TemporaryDirectory()
model_dir = model_temp.name

if zipfile.is_zipfile(par["model"]):
print(f"\n>>> Extracting model from .zip...", flush=True)
print(f".zip path: '{par['model']}'", flush=True)
with zipfile.ZipFile(par["model"], "r") as zip_file:
zip_file.extractall(model_dir)
elif tarfile.is_tarfile(par["model"]) and par["model"].endswith(
".tar.gz"
):
print(f"\n>>> Extracting model from .tar.gz...", flush=True)
print(f".tar.gz path: '{par['model']}'", flush=True)
with tarfile.open(par["model"], "r:gz") as tar_file:
tar_file.extractall(model_dir)
model_dir = os.path.join(model_dir, os.listdir(model_dir)[0])
else:
            raise ValueError(
                "The 'model' argument should be a directory, a .zip file or a .tar.gz file"
            )

print(f"Model directory: '{model_dir}'", flush=True)

print("\n>>> Embedding data...", flush=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: '{device}'", flush=True)
embedded = scgpt.tasks.embed_data(
adata,
model_dir.name,
model_dir,
gene_col="feature_name",
batch_size=64,
use_fast_transformer=False, # Disable fast-attn as not installed
Expand All @@ -86,7 +121,8 @@
print(f"Output H5AD file: '{par['output']}'", flush=True)
output.write_h5ad(par["output"], compression="gzip")

print("\n>>> Cleaning up temporary directories...", flush=True)
model_dir.cleanup()
if model_temp is not None:
print("\n>>> Cleaning up temporary directories...", flush=True)
model_temp.cleanup()

print("\n>>> Done!", flush=True)
4 changes: 3 additions & 1 deletion src/workflows/run_benchmark/main.nf
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ methods = [
scalex,
scanorama,
scanvi,
scgpt,
    scgpt.run(
        args: [model: file("s3://openproblems-work/cache/scGPT_human.zip")]
    ),
scimilarity.run(
args: [model: file("s3://openproblems-work/cache/scimilarity-model_v1.1.tar.gz")]
),
Expand Down

0 comments on commit dd18949

Please sign in to comment.