Skip to content

Commit

Permalink
[pred_mod] Refactor guanlab method (#399)
Browse files Browse the repository at this point in the history
* WIP refactor script

* Add change comments

* Add to workflow
  • Loading branch information
KaiWaldrant authored Mar 12, 2024
1 parent 0ae327f commit fbfebe2
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 71 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
# The API specifies which type of component this is.
__merge__: ../../api/comp_method.yaml

functionality:

name: guanlab_dengkw_pm

info:
label: Guanlab-dengkw
summary: A kernel ridge regression method with RBF kernel.
Expand All @@ -14,8 +10,6 @@ functionality:
reference: lance2022multimodal
documentation_url: https://github.com/openproblems-bio/neurips2021_multimodal_topmethods/tree/main/src/predict_modality/methods/Guanlab-dengkw
repository_url: https://github.com/openproblems-bio/neurips2021_multimodal_topmethods/tree/main/src/predict_modality/methods/Guanlab-dengkw

# Component-specific parameters (optional)
arguments:
- name: "--distance_method"
type: "string"
Expand All @@ -26,11 +20,9 @@ functionality:
type: "integer"
default: 50
description: Number of components to use for dimensionality reduction.

resources:
- type: python_script
path: script.py

platforms:
- type: docker
image: ghcr.io/openproblems-bio/base_python:1.0.2
Expand All @@ -40,10 +32,6 @@ platforms:
- scikit-learn
- pandas
- numpy
- scanpy

- type: native

- type: nextflow
directives:
label: [ "hightime", highmem, highcpu]
88 changes: 30 additions & 58 deletions src/tasks/predict_modality/methods/guanlab_dengkw_pm/script.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import anndata as ad
import logging
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.decomposition import TruncatedSVD
import gc
from scipy.sparse import csc_matrix
from sklearn.gaussian_process.kernels import RBF
from sklearn.kernel_ridge import KernelRidge
logging.basicConfig(level=logging.INFO)

## VIASH START
par = {
Expand All @@ -21,11 +19,15 @@
}
## VIASH END


## Removed PCA and normalization steps, as they arr already performed with the input data
print('Reading input files', flush=True)
input_train_mod1 = ad.read_h5ad(par['input_train_mod1'])
input_train_mod2 = ad.read_h5ad(par['input_train_mod2'])
input_test_mod1 = ad.read_h5ad(par['input_test_mod1'])

dataset_id = input_train_mod1.uns['dataset_id']

pred_dimx = input_test_mod1.shape[0]
pred_dimy = input_train_mod2.shape[1]

Expand All @@ -48,55 +50,24 @@
index_unique="-"
)

logging.info('Determine parameters by the modalities')
mod1_type = input_train_mod1.uns["modality"]
mod1_type = mod1_type.upper()
mod2_type = input_train_mod2.uns["modality"]
mod2_type = mod2_type.upper()
n_comp_dict = {
("GEX", "ADT"): (300, 70, 10, 0.2),
("ADT", "GEX"): (None, 50, 10, 0.2),
("GEX", "ATAC"): (1000, 50, 10, 0.1),
("ATAC", "GEX"): (100, 70, 10, 0.1)
}
logging.info(f"{mod1_type}, {mod2_type}")
n_mod1, n_mod2, scale, alpha = n_comp_dict[(mod1_type, mod2_type)]
logging.info(f"{n_mod1}, {n_mod2}, {scale}, {alpha}")

# Perform PCA on the input data
logging.info('Models using the Truncated SVD to reduce the dimension')

if n_mod1 is not None and n_mod1 < input_train.shape[1]:
embedder_mod1 = TruncatedSVD(n_components=n_mod1)
mod1_pca = embedder_mod1.fit_transform(input_train.layers["counts"]).astype(np.float32)
train_matrix = mod1_pca[input_train.obs['group'] == 'train']
test_matrix = mod1_pca[input_train.obs['group'] == 'test']
else:
train_matrix = input_train_mod1.to_df(layer="counts").values.astype(np.float32)
test_matrix = input_test_mod1.to_df(layer="counts").values.astype(np.float32)

if n_mod2 is not None and n_mod2 < input_train_mod2.shape[1]:
embedder_mod2 = TruncatedSVD(n_components=n_mod2)
train_gs = embedder_mod2.fit_transform(input_train_mod2.layers["counts"]).astype(np.float32)
else:
train_gs = input_train_mod2.to_df(layer="counts").values.astype(np.float32)

del input_train

logging.info('Running normalization ...')
train_sd = np.std(train_matrix, axis=1).reshape(-1, 1)
train_sd[train_sd == 0] = 1
train_norm = (train_matrix - np.mean(train_matrix, axis=1).reshape(-1, 1)) / train_sd
train_norm = train_norm.astype(np.float32)
del train_matrix

test_sd = np.std(test_matrix, axis=1).reshape(-1, 1)
test_sd[test_sd == 0] = 1
test_norm = (test_matrix - np.mean(test_matrix, axis=1).reshape(-1, 1)) / test_sd
test_norm = test_norm.astype(np.float32)
del test_matrix

logging.info('Running KRR model ...')
print('Determine parameters by the modalities', flush=True)
mod1_type = input_train_mod1.uns["modality"].upper()
mod2_type = input_train_mod2.uns["modality"].upper()

scale = 10
alpha = 0.1 if (mod1_type == "ATAC" or mod2_type == "ATAC") else 0.2

train_norm = input_train_mod1.to_df(layer="normalized").values.astype(np.float32)
test_norm = input_test_mod1.to_df(layer="normalized").values.astype(np.float32)

train_gs = input_train_mod2.to_df(layer="normalized").values.astype(np.float32)

del input_train_mod1
del input_test_mod1
del input_train_mod2
gc.collect()

print('Running KRR model ...', flush=True)
y_pred = np.zeros((pred_dimx, pred_dimy), dtype=np.float32)
np.random.seed(1000)

Expand All @@ -107,12 +78,12 @@
if not batch:
batch = [batches[0]]

logging.info(batch)
print(batch, flush=True)
kernel = RBF(length_scale = scale)
krr = KernelRidge(alpha=alpha, kernel=kernel)
logging.info('Fitting KRR ... ')
print('Fitting KRR ... ', flush=True)
krr.fit(train_norm[feature_obs.batch.isin(batch)], train_gs[gs_obs.batch.isin(batch)])
y_pred += (krr.predict(test_norm) @ embedder_mod2.components_)
y_pred += krr.predict(test_norm)

np.clip(y_pred, a_min=0, a_max=None, out=y_pred)
if mod2_type == "ATAC":
Expand All @@ -123,7 +94,8 @@
# Store as sparse matrix to be efficient.
# Note that this might require different classifiers/embedders before-hand.
# Not every class is able to support such data structures.
y_pred = csr_matrix(y_pred)
## Changed from csr to csc matrix as this is more supported.
y_pred = csc_matrix(y_pred)

print("Write output AnnData to file", flush=True)
output = ad.AnnData(
Expand All @@ -133,7 +105,7 @@
obs = obs,
var = var,
uns = {
'dataset_id': input_train_mod1.uns['dataset_id'],
'dataset_id': dataset_id,
'method_id': meta['functionality_name']
}
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ functionality:
- name: predict_modality/methods/lm
- name: predict_modality/methods/newwave_knnr
- name: predict_modality/methods/random_forest
- name: predict_modality/methods/guanlab_dengkw_pm
- name: predict_modality/metrics/correlation
- name: predict_modality/metrics/mse
platforms:
Expand Down
3 changes: 2 additions & 1 deletion src/tasks/predict_modality/workflows/run_benchmark/main.nf
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ workflow run_wf {
knnr_r,
lm,
newwave_knnr,
random_forest
random_forest,
guanlab_dengkw_pm
]

// construct list of metrics
Expand Down

0 comments on commit fbfebe2

Please sign in to comment.