From dd18949c7853e53edb74009b2bb99c9849c94a11 Mon Sep 17 00:00:00 2001 From: Luke Zappia Date: Tue, 10 Dec 2024 13:42:46 +0100 Subject: [PATCH] Add argument to give model path to scGPT (#16) * Add model path argument to scGPT * Use cached scGPT model in benchmark workflow * Make scGPT inherit from base method * Swap scGPT model argument names --- src/methods/scgpt/config.vsh.yaml | 13 ++++-- src/methods/scgpt/script.py | 62 +++++++++++++++++++++++------ src/workflows/run_benchmark/main.nf | 4 +- 3 files changed, 62 insertions(+), 17 deletions(-) diff --git a/src/methods/scgpt/config.vsh.yaml b/src/methods/scgpt/config.vsh.yaml index a6a7283..263a490 100644 --- a/src/methods/scgpt/config.vsh.yaml +++ b/src/methods/scgpt/config.vsh.yaml @@ -1,4 +1,4 @@ -__merge__: ../../api/comp_method.yaml +__merge__: ../../api/base_method.yaml name: scgpt label: scGPT @@ -24,11 +24,18 @@ info: model: "scGPT_CP" arguments: - - name: --model + - name: --model_name type: string - description: String giving the scGPT model to use + description: String giving the name of the scGPT model to use choices: ["scGPT_human", "scGPT_CP"] default: "scGPT_human" + - name: --model + type: file + description: | + Path to the directory containing the scGPT model specified by model_name + or a .zip/.tar.gz archive to extract. If not given the model will be + downloaded. + required: false - name: --n_hvg type: integer default: 3000 diff --git a/src/methods/scgpt/script.py b/src/methods/scgpt/script.py index 3729795..dda4877 100644 --- a/src/methods/scgpt/script.py +++ b/src/methods/scgpt/script.py @@ -1,5 +1,8 @@ +import os import sys +import tarfile import tempfile +import zipfile import anndata as ad import gdown @@ -12,6 +15,7 @@ par = { "input": "resources_test/.../input.h5ad", "output": "output.h5ad", + "model_name": "scGPT_human", "model": "scGPT_human", "n_hvg": 3000, } @@ -43,23 +47,54 @@ print(adata, flush=True) -print(f"\n>>> Downloading '{par['model']}' model...", flush=True) -model_drive_ids = { - "scGPT_human": "1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y", - "scGPT_CP": "1_GROJTzXiAV8HB4imruOTk6PEGuNOcgB", -} -drive_path = f"https://drive.google.com/drive/folders/{model_drive_ids[par['model']]}" -model_dir = tempfile.TemporaryDirectory() -print(f"Downloading from '{drive_path}'", flush=True) -gdown.download_folder(drive_path, output=model_dir.name, quiet=True) -print(f"Model directory: '{model_dir.name}'", flush=True) +if par["model"] is None: + print(f"\n>>> Downloading '{par['model_name']}' model...", flush=True) + model_drive_ids = { + "scGPT_human": "1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y", + "scGPT_CP": "1_GROJTzXiAV8HB4imruOTk6PEGuNOcgB", + } + drive_path = ( + f"https://drive.google.com/drive/folders/{model_drive_ids[par['model_name']]}" + ) + model_temp = tempfile.TemporaryDirectory() + model_dir = model_temp.name + print(f"Downloading from '{drive_path}'", flush=True) + gdown.download_folder(drive_path, output=model_dir, quiet=True) +else: + if os.path.isdir(par["model"]): + print(f"\n>>> Using model directory...", flush=True) + model_temp = None + model_dir = par["model"] + else: + model_temp = tempfile.TemporaryDirectory() + model_dir = model_temp.name + + if zipfile.is_zipfile(par["model"]): + print(f"\n>>> Extracting model from .zip...", flush=True) + print(f".zip path: '{par['model']}'", flush=True) + with zipfile.ZipFile(par["model"], "r") as zip_file: + zip_file.extractall(model_dir) + elif tarfile.is_tarfile(par["model"]) and par["model"].endswith( + ".tar.gz" + ): + print(f"\n>>> Extracting model from .tar.gz...", flush=True) + print(f".tar.gz path: '{par['model']}'", flush=True) + with tarfile.open(par["model"], "r:gz") as tar_file: + tar_file.extractall(model_dir) + model_dir = os.path.join(model_dir, os.listdir(model_dir)[0]) + else: + raise ValueError( + f"The 'model' argument should be a directory a .zip file or a .tar.gz file" + ) + +print(f"Model directory: '{model_dir}'", flush=True) print("\n>>> Embedding data...", flush=True) device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Device: '{device}'", flush=True) embedded = scgpt.tasks.embed_data( adata, - model_dir.name, + model_dir, gene_col="feature_name", batch_size=64, use_fast_transformer=False, # Disable fast-attn as not installed @@ -86,7 +121,8 @@ print(f"Output H5AD file: '{par['output']}'", flush=True) output.write_h5ad(par["output"], compression="gzip") -print("\n>>> Cleaning up temporary directories...", flush=True) -model_dir.cleanup() +if model_temp is not None: + print("\n>>> Cleaning up temporary directories...", flush=True) + model_temp.cleanup() print("\n>>> Done!", flush=True) diff --git a/src/workflows/run_benchmark/main.nf b/src/workflows/run_benchmark/main.nf index afcb968..5252216 100644 --- a/src/workflows/run_benchmark/main.nf +++ b/src/workflows/run_benchmark/main.nf @@ -29,7 +29,9 @@ methods = [ scalex, scanorama, scanvi, - scgpt, + scgpt.run( + args: [model_path: file("s3://openproblems-work/cache/scGPT_human.zip")] + ), scimilarity.run( args: [model: file("s3://openproblems-work/cache/scimilarity-model_v1.1.tar.gz")] ),