generated from openproblems-bio/task_template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Kai Waldrant <[email protected]>
- Loading branch information
1 parent
abb0d96
commit df6ad1d
Showing
9 changed files
with
362 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -68,6 +68,17 @@ authors: | |
info: | ||
github: rcannood | ||
orcid: "0000-0003-3641-729X" | ||
- name: Xueer Chen | ||
roles: [ contributor ] | ||
info: | ||
github: xuerchen | ||
email: [email protected] | ||
- name: Jiwei Liu | ||
roles: [ contributor ] | ||
info: | ||
github: daxiongshu | ||
email: [email protected] | ||
orcid: "0000-0002-8799-9763" | ||
|
||
links: | ||
issue_tracker: https://github.com/openproblems-bio/task_predict_modality/issues | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
__merge__: ../../../api/comp_method_predict.yaml | ||
name: simplemlp_predict | ||
resources: | ||
- type: python_script | ||
path: script.py | ||
- path: ../resources/ | ||
engines: | ||
- type: docker | ||
image: openproblems/base_pytorch_nvidia:1.0.0 | ||
# run_args: ["--gpus all --ipc=host"] | ||
setup: | ||
- type: python | ||
pypi: | ||
- scikit-learn | ||
- scanpy | ||
- pytorch-lightning | ||
engines: | ||
- type: executable | ||
- type: nextflow | ||
directives: | ||
label: [highmem, hightime, midcpu, gpu, highsharedmem] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
from glob import glob | ||
import sys | ||
import numpy as np | ||
from scipy.sparse import csc_matrix | ||
import anndata as ad | ||
import torch | ||
from torch.utils.data import TensorDataset,DataLoader | ||
|
||
## VIASH START | ||
par = { | ||
'input_train_mod1': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_multiome/swap/train_mod1.h5ad', | ||
'input_train_mod2': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_multiome/swap/train_mod2.h5ad', | ||
'input_test_mod1': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_multiome/swap/test_mod1.h5ad', | ||
'input_model': 'output/model', | ||
'output': 'output/prediction' | ||
} | ||
meta = { | ||
'resources_dir': 'src/tasks/predict_modality/methods/simple_mlp', | ||
'cpus': 10 | ||
} | ||
## VIASH END | ||
|
||
resources_dir = f"{meta['resources_dir']}/resources" | ||
sys.path.append(resources_dir) | ||
from models import MLP | ||
import utils | ||
|
||
def _predict(model,dl): | ||
model = model.cuda() | ||
model.eval() | ||
yps = [] | ||
for x in dl: | ||
with torch.no_grad(): | ||
yp = model(x[0].cuda()) | ||
yps.append(yp.detach().cpu().numpy()) | ||
yp = np.vstack(yps) | ||
return yp | ||
|
||
|
||
print('Load data', flush=True) | ||
input_train_mod2 = ad.read_h5ad(par['input_train_mod2']) | ||
input_test_mod1 = ad.read_h5ad(par['input_test_mod1']) | ||
|
||
# determine variables | ||
mod_1 = input_test_mod1.uns['modality'] | ||
mod_2 = input_train_mod2.uns['modality'] | ||
|
||
task = f'{mod_1}2{mod_2}' | ||
|
||
print('Load ymean', flush=True) | ||
ymean_path = f"{par['input_model']}/{task}_ymean.npy" | ||
ymean = np.load(ymean_path) | ||
|
||
print('Start predict', flush=True) | ||
if task == 'GEX2ATAC': | ||
y_pred = ymean*np.ones([input_test_mod1.n_obs, input_test_mod1.n_vars]) | ||
else: | ||
folds = [0, 1, 2] | ||
|
||
ymean = torch.from_numpy(ymean).float() | ||
yaml_path=f"{resources_dir}/yaml/mlp_{task}.yaml" | ||
config = utils.load_yaml(yaml_path) | ||
X = input_test_mod1.layers["normalized"].toarray() | ||
X = torch.from_numpy(X).float() | ||
|
||
te_ds = TensorDataset(X) | ||
|
||
yp = 0 | ||
for fold in folds: | ||
# load_path = f"{par['input_model']}/{task}_fold_{fold}/version_0/checkpoints/*" | ||
load_path = f"{par['input_model']}/{task}_fold_{fold}/**.ckpt" | ||
print(load_path) | ||
ckpt = glob(load_path)[0] | ||
model_inf = MLP.load_from_checkpoint( | ||
ckpt, | ||
in_dim=X.shape[1], | ||
out_dim=input_test_mod1.n_vars, | ||
ymean=ymean, | ||
config=config | ||
) | ||
te_loader = DataLoader( | ||
te_ds, | ||
batch_size=config.batch_size, | ||
num_workers=0, | ||
shuffle=False, | ||
drop_last=False | ||
) | ||
yp = yp + _predict(model_inf, te_loader) | ||
|
||
y_pred = yp/len(folds) | ||
|
||
y_pred = csc_matrix(y_pred) | ||
|
||
adata = ad.AnnData( | ||
layers={"normalized": y_pred}, | ||
shape=y_pred.shape, | ||
uns={ | ||
'dataset_id': input_test_mod1.uns['dataset_id'], | ||
'method_id': meta['functionality_name'], | ||
}, | ||
) | ||
|
||
print('Write data', flush=True) | ||
adata.write_h5ad(par['output'], compression = "gzip") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
__merge__: ../../../api/comp_method_train.yaml | ||
name: simplemlp | ||
label: Simple MLP | ||
summary: Ensemble of MLPs trained on different sites (team AXX) | ||
description: | | ||
This folder contains the AXX solution to the OpenProblems-NeurIPS2021 Single-Cell Multimodal Data Integration. | ||
Team took the 4th place of the modality prediction task in terms of overall ranking of 4 subtasks: namely GEX | ||
to ADT, ADT to GEX, GEX to ATAC and ATAC to GEX. Specifically, our methods ranked 3rd in GEX to ATAC and 4th | ||
in GEX to ADT. More details about the task can be found in the | ||
[competition webpage](https://openproblems.bio/events/2021-09_neurips/documentation/about_tasks/task1_modality_prediction). | ||
references: | ||
doi: 10.1101/2022.04.11.487796 | ||
links: | ||
documentation: https://github.com/openproblems-bio/neurips2021_multimodal_topmethods/tree/main/src/predict_modality/methods/AXX | ||
repository: https://github.com/openproblems-bio/neurips2021_multimodal_topmethods/tree/main/src/predict_modality/methods/AXX | ||
info: | ||
preferred_normalization: log_cp10k | ||
competition_submission_id: 170812 | ||
resources: | ||
- path: main.nf | ||
type: nextflow_script | ||
entrypoint: run_wf | ||
dependencies: | ||
- name: predict_modality/methods/simplemlp_train | ||
- name: predict_modality/methods/simplemlp_predict | ||
runners: | ||
- type: nextflow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
workflow run_wf { | ||
take: input_ch | ||
main: | ||
output_ch = input_ch | ||
|
||
| simplemlp_train.run( | ||
fromState: ["input_train_mod1", "input_train_mod2"], | ||
toState: ["input_model": "output"] | ||
) | ||
|
||
| simplemlp_predict.run( | ||
fromState: ["input_train_mod2", "input_test_mod1", "input_model", "input_transform"], | ||
toState: ["output": "output"] | ||
) | ||
|
||
| map { tup -> | ||
[tup[0], [output: tup[1].output]] | ||
} | ||
|
||
emit: output_ch | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
__merge__: ../../../api/comp_method_train.yaml | ||
name: simplemlp_train | ||
resources: | ||
- type: python_script | ||
path: script.py | ||
- path: ../resources/ | ||
engines: | ||
- type: docker | ||
image: openproblems/base_pytorch_nvidia:1.0.0 | ||
# run_args: ["--gpus all --ipc=host"] | ||
setup: | ||
- type: python | ||
pypi: | ||
- scikit-learn | ||
- scanpy | ||
- pytorch-lightning | ||
runners: | ||
- type: executable | ||
- type: nextflow | ||
directives: | ||
label: [highmem, hightime, midcpu, gpu, highsharedmem] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
import os | ||
import math | ||
import logging | ||
from pathlib import Path | ||
|
||
import anndata as ad | ||
import numpy as np | ||
|
||
import torch | ||
import pytorch_lightning as pl | ||
from torch.utils.data import TensorDataset, DataLoader | ||
from pytorch_lightning.callbacks import ModelCheckpoint | ||
from pytorch_lightning.loggers import TensorBoardLogger,WandbLogger | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
|
||
## VIASH START | ||
par = { | ||
'input_train_mod1': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_multiome/swap/train_mod1.h5ad', | ||
'input_train_mod2': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_multiome/swap/train_mod2.h5ad', | ||
'input_test_mod1': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_multiome/swap/test_mod1.h5ad', | ||
'output': 'output/model' | ||
} | ||
meta = { | ||
'resources_dir': 'src/tasks/predict_modality/methods/simple_mlp', | ||
'cpus': 10 | ||
} | ||
## VIASH END | ||
|
||
resources_dir = f"{meta['resources_dir']}/resources" | ||
|
||
import sys | ||
sys.path.append(resources_dir) | ||
from models import MLP | ||
import utils | ||
|
||
def _train(X, y, Xt, yt, logger, config, num_workers): | ||
|
||
X = torch.from_numpy(X).float() | ||
y = torch.from_numpy(y).float() | ||
ymean = torch.mean(y, dim=0, keepdim=True) | ||
|
||
tr_ds = TensorDataset(X,y) | ||
tr_loader = DataLoader( | ||
tr_ds, | ||
batch_size=config.batch_size, | ||
num_workers=num_workers, | ||
shuffle=True, | ||
drop_last=True | ||
) | ||
|
||
Xt = torch.from_numpy(Xt).float() | ||
yt = torch.from_numpy(yt).float() | ||
te_ds = TensorDataset(Xt,yt) | ||
te_loader = DataLoader( | ||
te_ds, | ||
batch_size=config.batch_size, | ||
num_workers=num_workers, | ||
shuffle=False, | ||
drop_last=False | ||
) | ||
|
||
checkpoint_callback = ModelCheckpoint( | ||
monitor='valid_RMSE', | ||
dirpath=logger.save_dir, | ||
save_top_k=1, | ||
) | ||
|
||
trainer = pl.Trainer( | ||
devices="auto", | ||
enable_checkpointing=True, | ||
logger=logger, | ||
max_epochs=config.epochs, | ||
callbacks=[checkpoint_callback], | ||
default_root_dir=logger.save_dir, | ||
# progress_bar_refresh_rate=5 | ||
) | ||
|
||
net = MLP(X.shape[1], y.shape[1], ymean, config) | ||
trainer.fit(net, tr_loader, te_loader) | ||
|
||
yp = trainer.predict(net, te_loader, ckpt_path='best') | ||
yp = torch.cat(yp, dim=0) | ||
|
||
score = ((yp-yt)**2).mean()**0.5 | ||
print(f"VALID RMSE {score:.3f}") | ||
del trainer | ||
return score,yp.detach().numpy() | ||
|
||
|
||
|
||
input_train_mod1 = ad.read_h5ad(par['input_train_mod1']) | ||
input_train_mod2 = ad.read_h5ad(par['input_train_mod2']) | ||
|
||
mod_1 = input_train_mod1.uns["modality"] | ||
mod_2 = input_train_mod2.uns["modality"] | ||
|
||
task = f'{mod_1}2{mod_2}' | ||
yaml_path = f'{resources_dir}/yaml/mlp_{task}.yaml' | ||
|
||
obs_info = utils.to_site_donor(input_train_mod1) | ||
# TODO: if we want this method to work for other datasets, resolve dependence on site notation | ||
sites = obs_info.site.unique() | ||
|
||
os.makedirs(par['output'], exist_ok=True) | ||
|
||
print('Compute ymean', flush=True) | ||
ymean = np.asarray(input_train_mod2.layers["normalized"].mean(axis=0)) | ||
path = f"{par['output']}/{task}_ymean.npy" | ||
np.save(path, ymean) | ||
|
||
|
||
if task == "GEX2ATAC": | ||
logging.info(f"No training required for this task ({task}).") | ||
sys.exit(0) | ||
|
||
if not os.path.exists(yaml_path): | ||
logging.error(f"No configuration file found for task '{task}'") | ||
sys.exit(1) | ||
|
||
yaml_path = f'{resources_dir}/yaml/mlp_{task}.yaml' | ||
yps = [] | ||
scores = [] | ||
|
||
msgs = {} | ||
# TODO: if we want this method to work for other datasets, dont use hardcoded range | ||
for fold in range(3): | ||
|
||
run_name = f"{task}_fold_{fold}" | ||
save_path = f"{par['output']}/{run_name}" | ||
num_workers = meta["cpus"] or 0 | ||
|
||
Path(save_path).mkdir(parents=True, exist_ok=True) | ||
|
||
X,y,Xt,yt = utils.split(input_train_mod1, input_train_mod2, fold) | ||
|
||
logger = TensorBoardLogger(save_path, name='') | ||
|
||
config = utils.load_yaml(yaml_path) | ||
|
||
if config.batch_size > X.shape[0]: | ||
config = config._replace(batch_size=math.ceil(X.shape[0] / 2)) | ||
|
||
score, yp = _train(X, y, Xt, yt, logger, config, num_workers) | ||
yps.append(yp) | ||
scores.append(score) | ||
msg = f"{task} Fold {fold} RMSE {score:.3f}" | ||
msgs[f'Fold {fold}'] = f'{score:.3f}' | ||
print(msg) | ||
|
||
yp = np.concatenate(yps) | ||
score = np.mean(scores) | ||
msgs['Overall'] = f'{score:.3f}' | ||
print('Overall', f'{score:.3f}') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters