From d4e30fec1b8e72846692a41ae29fe54fa7dc51e1 Mon Sep 17 00:00:00 2001 From: Robrecht Cannoodt Date: Mon, 23 Sep 2024 23:08:51 +0200 Subject: [PATCH] initial benchmark implementation --- scripts/run_benchmark/run_test_local.sh | 15 +- src/workflows/run_benchmark/config.vsh.yaml | 24 +- src/workflows/run_benchmark/main.nf | 609 ++++++++++++-------- 3 files changed, 389 insertions(+), 259 deletions(-) diff --git a/scripts/run_benchmark/run_test_local.sh b/scripts/run_benchmark/run_test_local.sh index ca86340a..c0bcb94d 100755 --- a/scripts/run_benchmark/run_test_local.sh +++ b/scripts/run_benchmark/run_test_local.sh @@ -6,14 +6,6 @@ REPO_ROOT=$(git rev-parse --show-toplevel) # ensure that the command below is run from the root of the repository cd "$REPO_ROOT" -# remove this when you have implemented the script -echo "TODO: once the 'run_benchmark' workflow has been implemented, update this script to use it." -echo " Step 1: replace 'task_template' with the name of the task in the following command." -echo " Step 2: replace the rename keys parameters to fit your run_benchmark inputs" -echo " Step 3: replace the settings parameter to fit your run_benchmark outputs" -echo " Step 4: remove this message" -exit 1 - set -e echo "Running benchmark on test data" @@ -28,9 +20,8 @@ nextflow run . \ -profile docker \ -resume \ -c common/nextflow_helpers/labels_ci.config \ - --id cxg_mouse_pancreas_atlas \ - --input_train resources_test/task_template/cxg_mouse_pancreas_atlas/train.h5ad \ - --input_test resources_test/task_template/cxg_mouse_pancreas_atlas/test.h5ad \ - --input_solution resources_test/task_template/cxg_mouse_pancreas_atlas/solution.h5ad \ + --id mouse_brain_combined \ + --input_sc resources_test/task_ist_preprocessing/mouse_brain_combined/scrnaseq_reference.h5ad \ + --input_sp resources_test/task_ist_preprocessing/mouse_brain_combined/raw_ist.zarr \ --output_state state.yaml \ --publish_dir "$publish_dir" diff --git a/src/workflows/run_benchmark/config.vsh.yaml b/src/workflows/run_benchmark/config.vsh.yaml index cd026cf8..0eaa3526 100644 --- a/src/workflows/run_benchmark/config.vsh.yaml +++ b/src/workflows/run_benchmark/config.vsh.yaml @@ -42,12 +42,6 @@ argument_groups: required: true direction: output default: task_info.yaml - - name: Methods - arguments: - - name: "--method_ids" - type: string - multiple: true - description: A list of method ids to run. If not specified, all methods will be run. 
resources: - type: nextflow_script @@ -57,13 +51,17 @@ resources: path: /_viash.yaml dependencies: - # - name: common/check_dataset_schema - # repository: openproblems-v2 - # - name: common/extract_metadata - # repository: openproblems-v2 - # - name: control_methods/true_labels - # - name: methods/logistic_regression - # - name: metrics/accuracy + - name: h5ad/extract_uns_metadata + repository: core + - name: methods_segmentation/custom_segmentation + - name: methods_transcript_assignment/basic_transcript_assignment + - name: methods_count_aggregation/basic_count_aggregation + - name: methods_qc_filter/basic_qc_filter + - name: methods_calculate_cell_volume/alpha_shapes + - name: methods_normalization/normalize_by_volume + - name: methods_cell_type_annotation/ssam + - name: methods_expression_correction/gene_efficiency_correction + - name: metrics/similarity runners: - type: nextflow diff --git a/src/workflows/run_benchmark/main.nf b/src/workflows/run_benchmark/main.nf index d3d0310c..58116229 100644 --- a/src/workflows/run_benchmark/main.nf +++ b/src/workflows/run_benchmark/main.nf @@ -1,5 +1,5 @@ workflow auto { - findStatesTemp(params, meta.config) + findStates(params, meta.config) | meta.workflow.run( auto: [publish: "state"] ) @@ -11,135 +11,356 @@ workflow run_wf { main: - // construct list of methods - methods = [ - true_labels, - logistic_regression - ] - - // construct list of metrics - metrics = [ - accuracy - ] - - /**************************** - * EXTRACT DATASET METADATA * - ****************************/ - dataset_ch = input_ch - // store join id - | map{ id, state -> - [id, state + ["_meta": [join_id: id]]] + /**************************************** + * INITIALIZE DATA STRUCTURES * + ****************************************/ + init_ch = input_ch + | map { id, state -> + def new_state = state + [ + steps: [ + [type: "dataset", dataset_id: id] + ] + ] + [id, new_state] } - // extract the dataset metadata - | extract_metadata.run( - fromState: [input: "input_solution"], - toState: { id, output, state -> + /**************************************** + * RUN SEGMENTATION METHODS * + ****************************************/ + segmentation_methods = [ + custom_segmentation.run( + args: [labels_key: "cell_labels"] + ) + ] + segm_ch = init_ch + | runEach( + components: segmentation_methods, + id: { id, state, comp -> + id + "/segm_" + comp.name + }, + fromState: ["input": "input_sp"], + toState: { id, out_dict, state, comp -> state + [ - dataset_uns: readYaml(output.output).uns + steps: state.steps + [[ + type: "segmentation", + component_id: comp.name, + run_id: id + ]], + output_segmentation: out_dict.output ] } ) - /*************************** - * RUN METHODS AND METRICS * - ***************************/ - score_ch = dataset_ch - - // run all methods + /**************************************** + * RUN ASSIGNMENT AFTER SEGMENTATION * + ****************************************/ + segm_ass_methods = [ + basic_transcript_assignment.run( + args: [ + transcripts_key: "transcripts", + coordinate_system: "global" + ] + ) + ] + segm_ass_ch = segm_ch | runEach( - components: methods, - - // use the 'filter' argument to only run a method on the normalization the component is asking for - filter: { id, state, comp -> - def norm = state.dataset_uns.normalization_id - def pref = comp.config.info.preferred_normalization - // if the preferred normalization is none at all, - // we can pass whichever dataset we want - def norm_check = (norm == "log_cp10k" && pref == "counts") || norm == pref - def 
method_check = !state.method_ids || state.method_ids.contains(comp.config.name)
-
-      method_check && norm_check
-    },
-
-    // define a new 'id' by appending the method name to the dataset id
+      components: segm_ass_methods,
       id: { id, state, comp ->
-        id + "." + comp.config.name
+        id + "/ass_" + comp.name
       },
-
-    // use 'fromState' to fetch the arguments the component requires from the overall state
-    fromState: { id, state, comp ->
-      def new_args = [
-        input_train: state.input_train,
-        input_test: state.input_test
+      fromState: [
+        input_ist: "input_sp",
+        input_scrnaseq: "input_sc",
+        input_segmentation: "output_segmentation"
+      ],
+      toState: { id, out_dict, state, comp ->
+        state + [
+          steps: state.steps + [[
+            type: "assignment",
+            component_id: comp.name,
+            run_id: id
+          ]],
+          output_assignment: out_dict.output
         ]
-      if (comp.config.info.type == "control_method") {
-        new_args.input_solution = state.input_solution
-      }
-      new_args
+      }
+    )
+
+  /****************************************
+   *         RUN DIRECT ASSIGNMENT        *
+   ****************************************/
+
+  // TODO: implement this when direct assignment methods are added
+  direct_ass_methods = []
+  direct_ass_ch = Channel.empty()
+  // direct_ass_ch = init_ch
+  //   | runEach(
+  //     components: direct_ass_methods,
+  //     id: { id, state, comp ->
+  //       id + "/ass_" + comp.name
+  //     },
+  //     fromState: [
+  //       input_ist: "input_sp",
+  //       input_scrnaseq: "input_sc"
+  //     ],
+  //     toState: { id, out_dict, state, comp ->
+  //       state + [
+  //         steps: state.steps + [[
+  //           type: "assignment",
+  //           run_id: id,
+  //           component_id: comp.name,
+  //           input_state: state,
+  //           output_dict: out_dict
+  //         ]],
+  //         output_assignment: out_dict.output
+  //       ]
+  //     }
+  //   )
+
+  /****************************************
+   *          COMBINE ASSIGNMENT          *
+   ****************************************/
+  assignment_ch = segm_ass_ch.mix(direct_ass_ch)
+
+  /****************************************
+   *          COUNT AGGREGATION           *
+   ****************************************/
+  count_aggr_methods = [
+    basic_count_aggregation
+  ]
+  count_aggr_ch = assignment_ch
+    | runEach(
+      components: count_aggr_methods,
+      id: { id, state, comp ->
+        id + "/aggr_" + comp.name
       },
-
-    // use 'toState' to publish that component's outputs to the overall state
-    toState: { id, output, state, comp ->
+      fromState: [
+        input: "output_assignment"
+      ],
+      toState: { id, out_dict, state, comp ->
         state + [
-        method_id: comp.config.name,
-        method_output: output.output
+          steps: state.steps + [[
+            type: "count_aggregation",
+            component_id: comp.name,
+            run_id: id
+          ]],
+          output_count_aggregation: out_dict.output
         ]
       }
     )
-  // run all metrics
+
+  /****************************************
+   *               QC FILTER              *
+   ****************************************/
+  qc_filter_methods = [
+    basic_qc_filter
+  ]
+  qc_filter_ch = count_aggr_ch
    | runEach(
-      components: metrics,
+      components: qc_filter_methods,
      id: { id, state, comp ->
-        id + "." 
+ comp.config.name + id + "/qc_filter_" + comp.name }, - // use 'fromState' to fetch the arguments the component requires from the overall state fromState: [ - input_solution: "input_solution", - input_prediction: "method_output" + input: "output_count_aggregation" ], - // use 'toState' to publish that component's outputs to the overall state - toState: { id, output, state, comp -> + toState: { id, out_dict, state, comp -> state + [ - metric_id: comp.config.name, - metric_output: output.output + steps: state.steps + [[ + type: "qc_filter", + component_id: comp.name, + run_id: id + ]], + output_qc_filter: out_dict.output ] } ) - /****************************** - * GENERATE OUTPUT YAML FILES * - ******************************/ - // TODO: can we store everything below in a separate helper function? - // extract the dataset metadata - dataset_meta_ch = dataset_ch - // only keep one of the normalization methods - | filter{ id, state -> - state.dataset_uns.normalization_id == "log_cp10k" - } - | joinStates { ids, states -> - // store the dataset metadata in a file - def dataset_uns = states.collect{state -> - def uns = state.dataset_uns.clone() - uns.remove("normalization_id") - uns + /**************************************** + * VOLUME CALCULATION * + ****************************************/ + cell_vol_methods = [ + alpha_shapes + ] + cell_vol_ch = qc_filter_ch + | runEach( + components: cell_vol_methods, + id: { id, state, comp -> + id + "/cell_vol_" + comp.name + }, + fromState: [ + input: "output_assignment" + ], + toState: { id, out_dict, state, comp -> + state + [ + steps: state.steps + [[ + type: "calculate_cell_volume", + component_id: comp.name, + run_id: id + ]], + output_cell_volume: out_dict.output + ] } - def dataset_uns_yaml_blob = toYamlBlob(dataset_uns) - def dataset_uns_file = tempFile("dataset_uns.yaml") - dataset_uns_file.write(dataset_uns_yaml_blob) + ) - ["output", [output_dataset_info: dataset_uns_file]] - } + /**************************************** + * NORMALIZATION BY VOLUME * + ****************************************/ + vol_norm_methods = [ + normalize_by_volume + ] + vol_norm_ch = cell_vol_ch + | runEach( + components: vol_norm_methods, + id: { id, state, comp -> + id + "/norm_" + comp.name + }, + fromState: [ + input_spatial_aggregated_counts: "output_count_aggregation", + input_cell_volumes: "output_cell_volume" + ], + toState: { id, out_dict, state, comp -> + state + [ + steps: state.steps + [[ + type: "normalization", + component_id: comp.name, + run_id: id + ]], + output_normalization: out_dict.output + ] + } + ) + + + /**************************************** + * DIRECT NORMALIZATION * + ****************************************/ + + // TODO: implement this when direct normalization methods are added + direct_norm_methods = [] + direct_norm_ch = Channel.empty() + // direct_norm_ch = qc_filter_ch + // | runEach( + // components: direct_norm_methods, + // id: { id, state, comp -> + // id + "/norm_" + comp.name + // }, + // fromState: [ + // input: "output_count_aggregation" + // ], + // toState: { id, out_dict, state, comp -> + // state + [ + // steps: state.steps + [[ + // type: "normalization", + // run_id: id, + // component_id: comp.name, + // input_state: state, + // output_dict: out_dict + // ]], + // output_normalization: out_dict.output + // ] + // } + // ) + + /**************************************** + * COMBINE NORMALIZATION * + ****************************************/ + normalization_ch = vol_norm_ch.mix(direct_norm_ch) + + + 
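+  // NOTE (hedged sketch): after the normalization stage, each channel item is
+  // expected to be a tuple of the chained run id and the accumulated state,
+  // with the keys set by the toState closures above. Shape roughly as follows
+  // (id and values illustrative, using the test dataset id):
+  //
+  //   [
+  //     "mouse_brain_combined/segm_custom_segmentation/ass_basic_transcript_assignment/aggr_basic_count_aggregation/qc_filter_basic_qc_filter/cell_vol_alpha_shapes/norm_normalize_by_volume",
+  //     [
+  //       input_sc: ..., input_sp: ...,
+  //       steps: [[type: "dataset", ...], [type: "segmentation", ...], ...],
+  //       output_segmentation: ..., output_assignment: ...,
+  //       output_count_aggregation: ..., output_qc_filter: ...,
+  //       output_cell_volume: ..., output_normalization: ...
+  //     ]
+  //   ]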
+  /****************************************
+   *         CELL TYPE ANNOTATION         *
+   ****************************************/
+  cta_methods = [
+    ssam
+  ]
+  cta_ch = normalization_ch
+    | runEach(
+      components: cta_methods,
+      id: { id, state, comp ->
+        id + "/cta_" + comp.name
+      },
+      fromState: [
+        input_spatial_normalized_counts: "output_normalization",
+        input_transcript_assignments: "output_assignment",
+        input_scrnaseq_reference: "input_sc"
+      ],
+      toState: { id, out_dict, state, comp ->
+        state + [
+          steps: state.steps + [[
+            type: "cell_type_annotation",
+            component_id: comp.name,
+            run_id: id
+          ]],
+          output_cta: out_dict.output
+        ]
+      }
+    )
+
+  /****************************************
+   *        EXPRESSION CORRECTION         *
+   ****************************************/
+  expr_corr_methods = [
+    gene_efficiency_correction
+  ]
+  expr_corr_ch = cta_ch
+    | runEach(
+      components: expr_corr_methods,
+      id: { id, state, comp ->
+        id + "/corr_" + comp.name
+      },
+      fromState: [
+        input_spatial_with_cell_types: "output_cta",
+        input_scrnaseq_reference: "input_sc"
+      ],
+      toState: { id, out_dict, state, comp ->
+        state + [
+          steps: state.steps + [[
+            type: "expression_correction",
+            component_id: comp.name,
+            run_id: id
+          ]],
+          output_correction: out_dict.output
+        ]
+      }
+    )
 
-  output_ch = score_ch
+  /****************************************
+   *                METRICS               *
+   ****************************************/
+  metrics = [
+    similarity
+  ]
+  metric_ch = expr_corr_ch
+    | runEach(
+      components: metrics,
+      id: { id, state, comp ->
+        id + "/metric_" + comp.name
+      },
+      fromState: [
+        input: "output_correction",
+        input_qc_col: "output_qc_filter",
+        input_sc: "input_sc",
+        input_transcript_assignments: "output_assignment"
+      ],
+      toState: { id, out_dict, state, comp ->
+        state + [
+          steps: state.steps + [[
+            type: "metric",
+            component_id: comp.name,
+            run_id: id
+          ]],
+          output_metric: out_dict.output
+        ]
+      }
+    )
 
     // extract the scores
-    | extract_metadata.run(
-      key: "extract_scores",
-      fromState: [input: "metric_output"],
+    | extract_uns_metadata.run(
+      key: "extract_uns_scores",
+      fromState: [input: "output_metric"],
       toState: { id, output, state ->
         state + [
           score_uns: readYaml(output.output).uns
@@ -148,42 +369,79 @@ workflow run_wf {
       )
 
    | joinStates { ids, states ->
-      // store the method configs in a file
-      def method_configs = methods.collect{it.config}
-      def method_configs_yaml_blob = toYamlBlob(method_configs)
-      def method_configs_file = tempFile("method_configs.yaml")
-      method_configs_file.write(method_configs_yaml_blob)
-
-      // store the metric configs in a file
-      def metric_configs = metrics.collect{it.config}
-      def metric_configs_yaml_blob = toYamlBlob(metric_configs)
-      def metric_configs_file = tempFile("metric_configs.yaml")
-      metric_configs_file.write(metric_configs_yaml_blob)
-
-      def viash_file = meta.resources_dir.resolve("_viash.yaml")
-      def viash_file_content = toYamlBlob(readYaml(viash_file).info)
-      def task_info_file = tempFile("task_info.yaml")
-      task_info_file.write(viash_file_content)
+      // TODO: determine what to store in the score_uns file
 
       // store the scores in a file
-      def score_uns = states.collect{it.score_uns}
+      def score_uns = states.collect{it.score_uns + [steps: it.steps]}
      def score_uns_yaml_blob = toYamlBlob(score_uns)
      def score_uns_file = tempFile("score_uns.yaml")
      score_uns_file.write(score_uns_yaml_blob)
 
+      ["output", [output_scores: score_uns_file]]
+    }
+
+
+  /******************************
+   * GENERATE OUTPUT YAML FILES *
+   ******************************/
+  // extract the dataset metadata
+  meta_ch = input_ch
+
+    // store join id
+    | map{ id, state ->
+      [id, state + ["_meta": [join_id: id]]]
+    }
+
+    // write placeholder files for the dataset info and config outputs
+    | joinStates { ids, states ->
+      // TODO: determine what to store in the dataset_uns file
+
+      def dataset_uns_file = tempFile("dataset_uns.yaml")
+      dataset_uns_file.write("")
+      def method_configs_file = tempFile("method_configs.yaml")
+      method_configs_file.write("")
+      def metric_configs_file = tempFile("metric_configs.yaml")
+      metric_configs_file.write("")
+
+      // // store the dataset metadata in a file
+      // def dataset_uns = states.collect{state ->
+      //   def uns = state.dataset_uns.clone()
+      //   uns.remove("normalization_id")
+      //   uns
+      // }
+      // def dataset_uns_yaml_blob = toYamlBlob(dataset_uns)
+      // def dataset_uns_file = tempFile("dataset_uns.yaml")
+      // dataset_uns_file.write(dataset_uns_yaml_blob)
+
+      // // store the method configs in a file
+      // def method_configs = methods.collect{it.config}
+      // def method_configs_yaml_blob = toYamlBlob(method_configs)
+      // def method_configs_file = tempFile("method_configs.yaml")
+      // method_configs_file.write(method_configs_yaml_blob)
+
+      // // store the metric configs in a file
+      // def metric_configs = metrics.collect{it.config}
+      // def metric_configs_yaml_blob = toYamlBlob(metric_configs)
+      // def metric_configs_file = tempFile("metric_configs.yaml")
+      // metric_configs_file.write(metric_configs_yaml_blob)
+
+      // retrieve task info
+      def viash_file = meta.resources_dir.resolve("_viash.yaml")
+
+      // create state
       def new_state = [
+        output_dataset_info: dataset_uns_file,
         output_method_configs: method_configs_file,
         output_metric_configs: metric_configs_file,
-        output_task_info: task_info_file,
-        output_scores: score_uns_file,
+        output_task_info: viash_file,
         _meta: states[0]._meta
       ]
       ["output", new_state]
     }
 
-  // merge all of the output data
-  | mix(dataset_meta_ch)
+  output_ch = metric_ch
+    | mix(meta_ch)
    | joinStates{ ids, states ->
      def mergedStates = states.inject([:]) { acc, m -> acc + m }
      [ids[0], mergedStates]
@@ -192,120 +450,3 @@ workflow run_wf {
 
   emit:
   output_ch
 }
-
-// temp fix for rename_keys typo
-
-def findStatesTemp(Map params, Map config) {
-  def auto_config = deepClone(config)
-  def auto_params = deepClone(params)
-
-  auto_config = auto_config.clone()
-  // override arguments
-  auto_config.argument_groups = []
-  auto_config.arguments = [
-    [
-      type: "string",
-      name: "--id",
-      description: "A dummy identifier",
-      required: false
-    ],
-    [
-      type: "file",
-      name: "--input_states",
-      example: "/path/to/input/directory/**/state.yaml",
-      description: "Path to input directory containing the datasets to be integrated.",
-      required: true,
-      multiple: true,
-      multiple_sep: ";"
-    ],
-    [
-      type: "string",
-      name: "--filter",
-      example: "foo/.*/state.yaml",
-      description: "Regex to filter state files by path.",
-      required: false
-    ],
-    // to do: make this a yaml blob?
-    [
-      type: "string",
-      name: "--rename_keys",
-      example: ["newKey1:oldKey1", "newKey2:oldKey2"],
-      description: "Rename keys in the detected input files. 
This is useful if the input files do not match the set of input arguments of the workflow.", - required: false, - multiple: true, - multiple_sep: ";" - ], - [ - type: "string", - name: "--settings", - example: '{"output_dataset": "dataset.h5ad", "k": 10}', - description: "Global arguments as a JSON glob to be passed to all components.", - required: false - ] - ] - if (!(auto_params.containsKey("id"))) { - auto_params["id"] = "auto" - } - - // run auto config through processConfig once more - auto_config = processConfig(auto_config) - - workflow findStatesTempWf { - helpMessage(auto_config) - - output_ch = - channelFromParams(auto_params, auto_config) - | flatMap { autoId, args -> - - def globalSettings = args.settings ? readYamlBlob(args.settings) : [:] - - // look for state files in input dir - def stateFiles = args.input_states - - // filter state files by regex - if (args.filter) { - stateFiles = stateFiles.findAll{ stateFile -> - def stateFileStr = stateFile.toString() - def matcher = stateFileStr =~ args.filter - matcher.matches()} - } - - // read in states - def states = stateFiles.collect { stateFile -> - def state_ = readTaggedYaml(stateFile) - [state_.id, state_] - } - - // construct renameMap - if (args.rename_keys) { - def renameMap = args.rename_keys.collectEntries{renameString -> - def split = renameString.split(":") - assert split.size() == 2: "Argument 'rename_keys' should be of the form 'newKey:oldKey;newKey:oldKey'" - split - } - - // rename keys in state, only let states through which have all keys - // also add global settings - states = states.collectMany{id, state -> - def newState = [:] - - for (key in renameMap.keySet()) { - def origKey = renameMap[key] - if (!(state.containsKey(origKey))) { - return [] - } - newState[key] = state[origKey] - } - - [[id, globalSettings + newState]] - } - } - - states - } - emit: - output_ch - } - - return findStatesTempWf -} \ No newline at end of file
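
Note (illustrative sketch, not part of the patch): with the joinStates block in
main.nf above, each entry written to score_uns.yaml is expected to be the metric
component's .uns dictionary merged with the provenance trail accumulated in
state.steps. The exact .uns keys depend on what the similarity metric emits; the
key names and values below are assumptions for illustration only:

    - metric_ids: [similarity]
      metric_values: [...]
      steps:
        - {type: dataset, dataset_id: mouse_brain_combined}
        - {type: segmentation, component_id: custom_segmentation, run_id: mouse_brain_combined/segm_custom_segmentation}
        - {type: assignment, component_id: basic_transcript_assignment, run_id: ...}
        - {type: count_aggregation, component_id: basic_count_aggregation, run_id: ...}
        - {type: qc_filter, component_id: basic_qc_filter, run_id: ...}
        - {type: calculate_cell_volume, component_id: alpha_shapes, run_id: ...}
        - {type: normalization, component_id: normalize_by_volume, run_id: ...}
        - {type: cell_type_annotation, component_id: ssam, run_id: ...}
        - {type: expression_correction, component_id: gene_efficiency_correction, run_id: ...}
        - {type: metric, component_id: similarity, run_id: ...}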