
add simple_mlp
Co-authored-by: Kai Waldrant <[email protected]>
rcannood and KaiWaldrant committed Aug 26, 2024
1 parent abb0d96 commit df6ad1d
Showing 9 changed files with 362 additions and 1 deletion.
11 changes: 11 additions & 0 deletions _viash.yaml
@@ -68,6 +68,17 @@ authors:
    info:
      github: rcannood
      orcid: "0000-0003-3641-729X"
  - name: Xueer Chen
    roles: [ contributor ]
    info:
      github: xuerchen
      email: [email protected]
  - name: Jiwei Liu
    roles: [ contributor ]
    info:
      github: daxiongshu
      email: [email protected]
      orcid: "0000-0002-8799-9763"

links:
  issue_tracker: https://github.com/openproblems-bio/task_predict_modality/issues
21 changes: 21 additions & 0 deletions src/methods/simple_mlp/predict/config.vsh.yaml
@@ -0,0 +1,21 @@
__merge__: ../../../api/comp_method_predict.yaml
name: simplemlp_predict
resources:
  - type: python_script
    path: script.py
  - path: ../resources/
engines:
  - type: docker
    image: openproblems/base_pytorch_nvidia:1.0.0
    # run_args: ["--gpus all --ipc=host"]
    setup:
      - type: python
        pypi:
          - scikit-learn
          - scanpy
          - pytorch-lightning
runners:
  - type: executable
  - type: nextflow
    directives:
      label: [highmem, hightime, midcpu, gpu, highsharedmem]
104 changes: 104 additions & 0 deletions src/methods/simple_mlp/predict/script.py
@@ -0,0 +1,104 @@
from glob import glob
import sys
import numpy as np
from scipy.sparse import csc_matrix
import anndata as ad
import torch
from torch.utils.data import TensorDataset, DataLoader

## VIASH START
par = {
    'input_train_mod1': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_multiome/swap/train_mod1.h5ad',
    'input_train_mod2': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_multiome/swap/train_mod2.h5ad',
    'input_test_mod1': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_multiome/swap/test_mod1.h5ad',
    'input_model': 'output/model',
    'output': 'output/prediction'
}
meta = {
    'resources_dir': 'src/tasks/predict_modality/methods/simple_mlp',
    'cpus': 10
}
## VIASH END

resources_dir = f"{meta['resources_dir']}/resources"
sys.path.append(resources_dir)
from models import MLP
import utils

def _predict(model, dl):
    # Run the model over a DataLoader and stack the batch predictions into one array.
    model = model.cuda()
    model.eval()
    yps = []
    for x in dl:
        with torch.no_grad():
            yp = model(x[0].cuda())
        yps.append(yp.detach().cpu().numpy())
    yp = np.vstack(yps)
    return yp


print('Load data', flush=True)
input_train_mod2 = ad.read_h5ad(par['input_train_mod2'])
input_test_mod1 = ad.read_h5ad(par['input_test_mod1'])

# determine variables
mod_1 = input_test_mod1.uns['modality']
mod_2 = input_train_mod2.uns['modality']

task = f'{mod_1}2{mod_2}'

print('Load ymean', flush=True)
ymean_path = f"{par['input_model']}/{task}_ymean.npy"
ymean = np.load(ymean_path)

print('Start predict', flush=True)
if task == 'GEX2ATAC':
    # No model is trained for GEX-to-ATAC: predict the per-feature training mean.
    y_pred = ymean * np.ones([input_test_mod1.n_obs, input_test_mod1.n_vars])
else:
    folds = [0, 1, 2]

    ymean = torch.from_numpy(ymean).float()
    yaml_path = f"{resources_dir}/yaml/mlp_{task}.yaml"
    config = utils.load_yaml(yaml_path)
    X = input_test_mod1.layers["normalized"].toarray()
    X = torch.from_numpy(X).float()

    te_ds = TensorDataset(X)

    # Average the predictions of the per-fold checkpoints.
    yp = 0
    for fold in folds:
        # load_path = f"{par['input_model']}/{task}_fold_{fold}/version_0/checkpoints/*"
        load_path = f"{par['input_model']}/{task}_fold_{fold}/**.ckpt"
        print(load_path)
        ckpt = glob(load_path)[0]
        model_inf = MLP.load_from_checkpoint(
            ckpt,
            in_dim=X.shape[1],
            out_dim=input_test_mod1.n_vars,
            ymean=ymean,
            config=config
        )
        te_loader = DataLoader(
            te_ds,
            batch_size=config.batch_size,
            num_workers=0,
            shuffle=False,
            drop_last=False
        )
        yp = yp + _predict(model_inf, te_loader)

    y_pred = yp / len(folds)

y_pred = csc_matrix(y_pred)

adata = ad.AnnData(
    layers={"normalized": y_pred},
    shape=y_pred.shape,
    uns={
        'dataset_id': input_test_mod1.uns['dataset_id'],
        'method_id': meta['functionality_name'],
    },
)

print('Write data', flush=True)
adata.write_h5ad(par['output'], compression="gzip")
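
Both the predict and train scripts import MLP from models.py in the shared ../resources/ folder, which is not part of this diff. As a rough sketch only, the calls they make (MLP(in_dim, out_dim, ymean, config), MLP.load_from_checkpoint(...), the valid_RMSE metric monitored by the checkpoint callback, and batch-wise prediction) imply a PyTorch Lightning module roughly along these lines; the hidden width, learning rate, loss, and the hidden_dim/lr config fields are illustrative assumptions, not the team's actual architecture:

import torch
import torch.nn as nn
import pytorch_lightning as pl


class MLP(pl.LightningModule):
    """Sketch of the interface assumed by the train/predict scripts; sizes are placeholders."""

    def __init__(self, in_dim, out_dim, ymean, config):
        super().__init__()
        self.register_buffer("ymean", ymean)          # per-feature mean of the target modality
        hidden = getattr(config, "hidden_dim", 512)   # assumed config field
        self.lr = getattr(config, "lr", 1e-3)         # assumed config field
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x):
        # Predict a residual around the training-set mean of the target modality.
        return self.net(x) + self.ymean

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        rmse = torch.sqrt(nn.functional.mse_loss(self(x), y))
        self.log("valid_RMSE", rmse)  # metric monitored by ModelCheckpoint in the train script

    def predict_step(self, batch, batch_idx):
        return self(batch[0])

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

The real models.py in the AXX repository may differ; the sketch only pins down the interface the two scripts depend on.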
27 changes: 27 additions & 0 deletions src/methods/simple_mlp/run/config.vsh.yaml
@@ -0,0 +1,27 @@
__merge__: ../../../api/comp_method_train.yaml
name: simplemlp
label: Simple MLP
summary: Ensemble of MLPs trained on different sites (team AXX)
description: |
  This component contains team AXX's solution to the OpenProblems NeurIPS 2021 Single-Cell
  Multimodal Data Integration competition. The team placed 4th in the modality prediction
  task in the overall ranking across its four subtasks, namely GEX to ADT, ADT to GEX,
  GEX to ATAC and ATAC to GEX. Specifically, the method ranked 3rd in GEX to ATAC and 4th
  in GEX to ADT. More details about the task can be found on the
  [competition webpage](https://openproblems.bio/events/2021-09_neurips/documentation/about_tasks/task1_modality_prediction).
references:
  doi: 10.1101/2022.04.11.487796
links:
  documentation: https://github.com/openproblems-bio/neurips2021_multimodal_topmethods/tree/main/src/predict_modality/methods/AXX
  repository: https://github.com/openproblems-bio/neurips2021_multimodal_topmethods/tree/main/src/predict_modality/methods/AXX
info:
  preferred_normalization: log_cp10k
  competition_submission_id: 170812
resources:
  - path: main.nf
    type: nextflow_script
    entrypoint: run_wf
dependencies:
  - name: predict_modality/methods/simplemlp_train
  - name: predict_modality/methods/simplemlp_predict
runners:
  - type: nextflow
21 changes: 21 additions & 0 deletions src/methods/simple_mlp/run/main.nf
@@ -0,0 +1,21 @@
workflow run_wf {
  take: input_ch
  main:
  output_ch = input_ch

    | simplemlp_train.run(
      fromState: ["input_train_mod1", "input_train_mod2"],
      toState: ["input_model": "output"]
    )

    | simplemlp_predict.run(
      fromState: ["input_train_mod2", "input_test_mod1", "input_model", "input_transform"],
      toState: ["output": "output"]
    )

    | map { tup ->
      [tup[0], [output: tup[1].output]]
    }

  emit: output_ch
}
21 changes: 21 additions & 0 deletions src/methods/simple_mlp/train/config.vsh.yaml
@@ -0,0 +1,21 @@
__merge__: ../../../api/comp_method_train.yaml
name: simplemlp_train
resources:
  - type: python_script
    path: script.py
  - path: ../resources/
engines:
  - type: docker
    image: openproblems/base_pytorch_nvidia:1.0.0
    # run_args: ["--gpus all --ipc=host"]
    setup:
      - type: python
        pypi:
          - scikit-learn
          - scanpy
          - pytorch-lightning
runners:
  - type: executable
  - type: nextflow
    directives:
      label: [highmem, hightime, midcpu, gpu, highsharedmem]
154 changes: 154 additions & 0 deletions src/methods/simple_mlp/train/script.py
@@ -0,0 +1,154 @@
import os
import math
import logging
from pathlib import Path

import anndata as ad
import numpy as np

import torch
import pytorch_lightning as pl
from torch.utils.data import TensorDataset, DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

logging.basicConfig(level=logging.INFO)

## VIASH START
par = {
    'input_train_mod1': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_multiome/swap/train_mod1.h5ad',
    'input_train_mod2': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_multiome/swap/train_mod2.h5ad',
    'input_test_mod1': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_multiome/swap/test_mod1.h5ad',
    'output': 'output/model'
}
meta = {
    'resources_dir': 'src/tasks/predict_modality/methods/simple_mlp',
    'cpus': 10
}
## VIASH END

resources_dir = f"{meta['resources_dir']}/resources"

import sys
sys.path.append(resources_dir)
from models import MLP
import utils

def _train(X, y, Xt, yt, logger, config, num_workers):
    # Train one MLP on the (X, y) fold and evaluate it on the held-out (Xt, yt) site.
    X = torch.from_numpy(X).float()
    y = torch.from_numpy(y).float()
    ymean = torch.mean(y, dim=0, keepdim=True)

    tr_ds = TensorDataset(X, y)
    tr_loader = DataLoader(
        tr_ds,
        batch_size=config.batch_size,
        num_workers=num_workers,
        shuffle=True,
        drop_last=True
    )

    Xt = torch.from_numpy(Xt).float()
    yt = torch.from_numpy(yt).float()
    te_ds = TensorDataset(Xt, yt)
    te_loader = DataLoader(
        te_ds,
        batch_size=config.batch_size,
        num_workers=num_workers,
        shuffle=False,
        drop_last=False
    )

    checkpoint_callback = ModelCheckpoint(
        monitor='valid_RMSE',
        dirpath=logger.save_dir,
        save_top_k=1,
    )

    trainer = pl.Trainer(
        devices="auto",
        enable_checkpointing=True,
        logger=logger,
        max_epochs=config.epochs,
        callbacks=[checkpoint_callback],
        default_root_dir=logger.save_dir,
        # progress_bar_refresh_rate=5
    )

    net = MLP(X.shape[1], y.shape[1], ymean, config)
    trainer.fit(net, tr_loader, te_loader)

    # Report the RMSE of the best checkpoint on the held-out fold.
    yp = trainer.predict(net, te_loader, ckpt_path='best')
    yp = torch.cat(yp, dim=0)

    score = ((yp - yt) ** 2).mean() ** 0.5
    print(f"VALID RMSE {score:.3f}")
    del trainer
    return score, yp.detach().numpy()



input_train_mod1 = ad.read_h5ad(par['input_train_mod1'])
input_train_mod2 = ad.read_h5ad(par['input_train_mod2'])

mod_1 = input_train_mod1.uns["modality"]
mod_2 = input_train_mod2.uns["modality"]

task = f'{mod_1}2{mod_2}'
yaml_path = f'{resources_dir}/yaml/mlp_{task}.yaml'

obs_info = utils.to_site_donor(input_train_mod1)
# TODO: if we want this method to work for other datasets, resolve dependence on site notation
sites = obs_info.site.unique()

os.makedirs(par['output'], exist_ok=True)

print('Compute ymean', flush=True)
ymean = np.asarray(input_train_mod2.layers["normalized"].mean(axis=0))
path = f"{par['output']}/{task}_ymean.npy"
np.save(path, ymean)


if task == "GEX2ATAC":
logging.info(f"No training required for this task ({task}).")
sys.exit(0)

if not os.path.exists(yaml_path):
logging.error(f"No configuration file found for task '{task}'")
sys.exit(1)

yaml_path = f'{resources_dir}/yaml/mlp_{task}.yaml'
yps = []
scores = []

msgs = {}
# TODO: if we want this method to work for other datasets, dont use hardcoded range
for fold in range(3):

    run_name = f"{task}_fold_{fold}"
    save_path = f"{par['output']}/{run_name}"
    num_workers = meta["cpus"] or 0

    Path(save_path).mkdir(parents=True, exist_ok=True)

    X, y, Xt, yt = utils.split(input_train_mod1, input_train_mod2, fold)

    logger = TensorBoardLogger(save_path, name='')

    config = utils.load_yaml(yaml_path)

    # Shrink the batch size if the fold has fewer cells than one batch.
    if config.batch_size > X.shape[0]:
        config = config._replace(batch_size=math.ceil(X.shape[0] / 2))

    score, yp = _train(X, y, Xt, yt, logger, config, num_workers)
    yps.append(yp)
    scores.append(score)
    msg = f"{task} Fold {fold} RMSE {score:.3f}"
    msgs[f'Fold {fold}'] = f'{score:.3f}'
    print(msg)

yp = np.concatenate(yps)
score = np.mean(scores)
msgs['Overall'] = f'{score:.3f}'
print('Overall', f'{score:.3f}')
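
The scripts also depend on utils.load_yaml, utils.to_site_donor and utils.split from the same shared ../resources/ folder, which this diff does not include. A minimal sketch of the behaviour the call sites imply — a namedtuple config (so config.batch_size and config._replace(...) work) and a leave-one-site-out split — assuming NeurIPS 2021-style batch labels such as "s1d1" (site 1, donor 1); the exact fold-to-site assignment below is an assumption:

from collections import namedtuple

import yaml


def load_yaml(path):
    # Load hyperparameters into a namedtuple so fields are attributes and
    # config._replace(batch_size=...) works as in the train script.
    with open(path) as f:
        params = yaml.safe_load(f)
    Config = namedtuple("Config", sorted(params))
    return Config(**params)


def to_site_donor(adata):
    # Assumes NeurIPS 2021-style batch labels such as "s1d2" (site 1, donor 2).
    df = adata.obs[["batch"]].copy()
    df["batch"] = df["batch"].astype(str)
    df["site"] = df["batch"].str[:2]
    df["donor"] = df["batch"].str[2:]
    return df


def split(mod1, mod2, fold):
    # Leave-one-site-out split: one site is held out for validation per fold,
    # the remaining sites form the training set.
    sites = to_site_donor(mod1)["site"]
    valid_site = sorted(sites.unique())[fold]
    mask = (sites == valid_site).to_numpy()
    X = mod1.layers["normalized"][~mask].toarray()
    y = mod2.layers["normalized"][~mask].toarray()
    Xt = mod1.layers["normalized"][mask].toarray()
    yt = mod2.layers["normalized"][mask].toarray()
    return X, y, Xt, yt

The actual helpers in the AXX code may map folds to sites differently, but this is enough to see how the three folds in the training loop above correspond to held-out sites.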
1 change: 1 addition & 0 deletions src/workflows/run_benchmark/config.vsh.yaml
@@ -73,6 +73,7 @@ dependencies:
  - name: methods/lm
  - name: methods/lmds_irlba_rf
  - name: methods/guanlab_dengkw_pm
  - name: methods/simple_mlp
  - name: metrics/correlation
  - name: metrics/mse
runners:
3 changes: 2 additions & 1 deletion src/workflows/run_benchmark/main.nf
@@ -21,7 +21,8 @@ workflow run_wf {
    knnr_r,
    lm,
    lmds_irlba_rf,
    guanlab_dengkw_pm
    guanlab_dengkw_pm,
    simple_mlp
  ]

  // construct list of metrics
