From a63a2eb45af5f1e78e8d7258ab9d60f2c09e30a7 Mon Sep 17 00:00:00 2001 From: jalil Date: Tue, 10 Sep 2024 09:59:46 +0200 Subject: [PATCH] workflow updated --- runs.ipynb | 11 ++---- scripts/run_grn_evaluation.sh | 34 ++++++++----------- src/api/comp_metric.yaml | 4 +++ src/exp_analysis/script.py | 1 + src/methods/single_omics/scgpt/script.py | 11 ++++++ src/metrics/regression_2/main.py | 28 ++++++++++----- .../run_grn_evaluation/config.vsh.yaml | 16 ++++----- src/workflows/run_grn_evaluation/main.nf | 3 ++ 8 files changed, 62 insertions(+), 46 deletions(-) diff --git a/runs.ipynb b/runs.ipynb index a51051b6a..1b0aa2a5d 100644 --- a/runs.ipynb +++ b/runs.ipynb @@ -2177,20 +2177,13 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ - "!aws s3 sync s3://openproblems-data/resources_test/grn/results/s3://openproblems-data/resources/grn/results/single_omics_all ./resources_test/results/s3://openproblems-data/resources/grn/results/single_omics_all" + "!aws s3 sync s3://openproblems-data/resources/grn/results/single_omics_all resources/results/single_omics_all" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": null, diff --git a/scripts/run_grn_evaluation.sh b/scripts/run_grn_evaluation.sh index eeec944a1..5ee9d29e5 100755 --- a/scripts/run_grn_evaluation.sh +++ b/scripts/run_grn_evaluation.sh @@ -3,28 +3,22 @@ # RUN_ID="run_$(date +%Y-%m-%d_%H-%M-%S)" reg_type=${1} #GB, ridge -RUN_ID="grn_evaluation_${reg_type}" -resources_dir="s3://openproblems-data/resources/grn" -# resources_dir="./resources" +RUN_ID="grn_evaluation_so_${reg_type}" +# resources_dir="s3://openproblems-data/resources/grn" +resources_dir="./resources" publish_dir="${resources_dir}/results/${RUN_ID}" grn_models_folder="${resources_dir}/grn_models" subsample=-2 max_workers=10 +layer=pearson +metric_ids="[regression_1]" -param_file="./params/${RUN_ID}_figr.yaml" - -# grn_names=( -# "collectri" -# "celloracle" -# "scenicplus" -# "figr" -# "granie" -# "scglue" -# ) +param_file="./params/${RUN_ID}.yaml" grn_names=( - "figr") + "genie3" + ) # Start writing to the YAML file cat > $param_file << HERE param_list: @@ -33,6 +27,7 @@ HERE append_entry() { cat >> $param_file << HERE - id: ${reg_type}_${1}_${3} + metric_ids: ${metric_ids} perturbation_data: ${resources_dir}/grn-benchmark/perturbation_data.h5ad reg_type: $reg_type method_id: $1 @@ -41,6 +36,7 @@ append_entry() { tf_all: ${resources_dir}/prior/tf_all.csv layer: ${3} consensus: ${resources_dir}/prior/consensus-num-regulators.json + HERE # Conditionally append the prediction line if the second argument is "true" @@ -50,14 +46,14 @@ HERE HERE fi } -layers=(scgen_pearson) + # Loop through grn_names and layers -for layer in "${layers[@]}"; do - for grn_name in "${grn_names[@]}"; do - append_entry "$grn_name" "true" "$layer" - done + +for grn_name in "${grn_names[@]}"; do + append_entry "$grn_name" "true" "$layer" done + # # Append negative control # grn_name="negative_control" # for layer in "${layers[@]}"; do diff --git a/src/api/comp_metric.yaml b/src/api/comp_metric.yaml index 6d6545998..bc8751586 100644 --- a/src/api/comp_metric.yaml +++ b/src/api/comp_metric.yaml @@ -12,14 +12,17 @@ functionality: __merge__: file_perturbation_h5ad.yaml required: false direction: input + default: resources/grn-benchmark/perturbation_data.h5ad - name: --prediction __merge__: file_prediction.yaml required: true direction: input + - name: --score __merge__: file_score.yaml required: false direction: output + default: output/score.h5ad - name: --reg_type type: string direction: input @@ -44,6 +47,7 @@ functionality: type: file direction: input example: resources_test/prior/tf_all.csv + default: resources/prior/tf_all.csv - name: --apply_tf type: boolean required: false diff --git a/src/exp_analysis/script.py b/src/exp_analysis/script.py index 966dc4dec..66e2eedc6 100644 --- a/src/exp_analysis/script.py +++ b/src/exp_analysis/script.py @@ -33,6 +33,7 @@ info_obj = Explanatory_analysis(net=tf_gene_net) print("Calculate basic stats") stats = info_obj.calculate_basic_stats() +print(stats) print("Outputting stats to :", par['stats']) with open(par['stats'], 'w') as ff: json.dump(stats, ff) diff --git a/src/methods/single_omics/scgpt/script.py b/src/methods/single_omics/scgpt/script.py index 9aa0665d3..066a30da8 100644 --- a/src/methods/single_omics/scgpt/script.py +++ b/src/methods/single_omics/scgpt/script.py @@ -14,6 +14,7 @@ import networkx as nx import pandas as pd import tqdm +import os # import gseapy as gp # from gears import PertData, GEARS @@ -54,6 +55,13 @@ } ## VIASH END +# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:50" +initial_memory = torch.cuda.memory_allocated() +def monitor_memory(): + used_memory = torch.cuda.memory_allocated() + data_moved = used_memory - initial_memory + print(f"Data moved to GPU: {data_moved} bytes") + # Load list of putative TFs tf_all = np.loadtxt(par['tf_all'], dtype=str) @@ -128,6 +136,7 @@ model.load_state_dict(model_dict) model.to(device) +monitor_memory() print('Process rna-seq file') @@ -201,6 +210,7 @@ dict_sum_condition = {} print('Extract gene gene links from attention layer') model.eval() +monitor_memory() with torch.no_grad(), torch.cuda.amp.autocast(enabled=True): M = all_gene_ids.size(1) N = all_gene_ids.size(0) @@ -210,6 +220,7 @@ outputs = np.zeros((batch_size, M, M), dtype=np.float32) # Replicate the operations in model forward pass src_embs = model.encoder(torch.tensor(all_gene_ids[i : i + batch_size], dtype=torch.long).to(device)) + # monitor_memory() val_embs = model.value_encoder(torch.tensor(all_values[i : i + batch_size], dtype=torch.float).to(device)) total_embs = src_embs + val_embs total_embs = model.bn(total_embs.permute(0, 2, 1)).permute(0, 2, 1) diff --git a/src/metrics/regression_2/main.py b/src/metrics/regression_2/main.py index 11d2400b8..ba5617d25 100644 --- a/src/metrics/regression_2/main.py +++ b/src/metrics/regression_2/main.py @@ -219,15 +219,24 @@ def static_approach( gene_names: List[str], tf_names: Set[str], reg_type: str, - n_jobs:int + n_jobs:int, + n_features_dict:dict ) -> float: # Cross-validate each gene using the inferred GRN to define select input features res = cross_validate(reg_type, gene_names, tf_names, X, groups, grn, n_features, n_jobs=n_jobs) - mean_r2_scores = np.asarray([res['scores'][j]['avg-r2'] for j in range(len(res['scores']))]) - mean_r2_scores = mean_r2_scores[mean_r2_scores>-10] + r2 = [] - return np.mean(mean_r2_scores) + for i in range(len(res['scores'])): + gene_name = res['gene_names'][i] + if n_features[n_features_dict[gene_name]] != 0: + r2.append(res['scores'][i]['avg-r2']) + + # mean_r2_scores = np.asarray([res['scores'][j]['avg-r2'] for j in range(len(res['scores']))]) + mean_r2_scores = float(np.mean(r2)) + + # return np.mean(mean_r2_scores) + return mean_r2_scores def main(par: Dict[str, Any]) -> pd.DataFrame: @@ -272,6 +281,9 @@ def main(par: Dict[str, Any]) -> pd.DataFrame: # Load consensus numbers of putative regulators with open(par['consensus'], 'r') as f: data = json.load(f) + gene_names_ = np.asarray(list(data.keys()), dtype=object) + n_features_dict = {gene_name: i for i, gene_name in enumerate(gene_names_)} + n_features_theta_min = np.asarray([data[gene_name]['0'] for gene_name in gene_names], dtype=int) n_features_theta_median = np.asarray([data[gene_name]['0.5'] for gene_name in gene_names], dtype=int) n_features_theta_max = np.asarray([data[gene_name]['1'] for gene_name in gene_names], dtype=int) @@ -284,16 +296,16 @@ def main(par: Dict[str, Any]) -> pd.DataFrame: # Evaluate GRN print(f'Compute metrics for layer: {layer}', flush=True) # print(f'Dynamic approach:', flush=True) - # print(f'Static approach (theta=0):', flush=True) - score_static_min = static_approach(grn, n_features_theta_min, X, groups, gene_names, tf_names, par['reg_type'], n_jobs=par['max_workers']) + print(f'Static approach (theta=0):', flush=True) + score_static_min = static_approach(grn, n_features_theta_min, X, groups, gene_names, tf_names, par['reg_type'], n_jobs=par['max_workers'], n_features_dict=n_features_dict) print(f'Static approach (theta=0.5):', flush=True) - score_static_median = static_approach(grn, n_features_theta_median, X, groups, gene_names, tf_names, par['reg_type'], n_jobs=par['max_workers']) + score_static_median = static_approach(grn, n_features_theta_median, X, groups, gene_names, tf_names, par['reg_type'], n_jobs=par['max_workers'], n_features_dict=n_features_dict) # print(f'Static approach (theta=1):', flush=True) # score_static_max = static_approach(grn, n_features_theta_max, X, groups, gene_names, tf_names, par['reg_type'], n_jobs=par['max_workers']) # TODO: find a mathematically sound way to combine Z-scores and r2 scores results = { - # 'static-theta-0.0': [float(score_static_min)], + 'static-theta-0.0': [float(score_static_min)], 'static-theta-0.5': [float(score_static_median)] # 'static-theta-1.0': [float(score_static_max)], } diff --git a/src/workflows/run_grn_evaluation/config.vsh.yaml b/src/workflows/run_grn_evaluation/config.vsh.yaml index 0dd99f945..ec7da7958 100644 --- a/src/workflows/run_grn_evaluation/config.vsh.yaml +++ b/src/workflows/run_grn_evaluation/config.vsh.yaml @@ -62,16 +62,12 @@ functionality: direction: output default: metric_configs.yaml - # - name: Arguments - # arguments: - # - name: "--predictions" - # type: string - # multiple: true - # description: A list of GRN models - # - name: "--layers" - # type: string - # multiple: true - # description: A list of GRN models + - name: Arguments + arguments: + - name: "--metric_ids" + type: string + multiple: true + description: A list of metric ids to run. If not specified, all metric will be run. resources: - type: nextflow_script diff --git a/src/workflows/run_grn_evaluation/main.nf b/src/workflows/run_grn_evaluation/main.nf index 3fac0b295..20d1cc7c5 100644 --- a/src/workflows/run_grn_evaluation/main.nf +++ b/src/workflows/run_grn_evaluation/main.nf @@ -61,6 +61,9 @@ workflow run_wf { id: { id, state, comp -> id + "." + comp.config.functionality.name }, + filter: { id, state, comp -> + !state.metric_ids || state.metric_ids.contains(comp.config.functionality.name) + }, // use 'fromState' to fetch the arguments the component requires from the overall state fromState: [ perturbation_data: "perturbation_data",