From 6510489b0b7f2641b829d140c810b70317a5fbf0 Mon Sep 17 00:00:00 2001 From: Kai Waldrant Date: Thu, 11 Jul 2024 11:24:25 +0200 Subject: [PATCH] add run benchmark workflow [WIP] --- src/workflows/run_benchmark/main.nf | 308 ++++++++++++++++++++++++++++ 1 file changed, 308 insertions(+) create mode 100644 src/workflows/run_benchmark/main.nf diff --git a/src/workflows/run_benchmark/main.nf b/src/workflows/run_benchmark/main.nf new file mode 100644 index 0000000..4d036c5 --- /dev/null +++ b/src/workflows/run_benchmark/main.nf @@ -0,0 +1,308 @@ +workflow auto { + findStatesTempWf(params, meta.config) + | meta.workflow.run( + auto: [publish: "state"] + ) +} + +workflow run_wf { + take: + input_ch + + main: + + // construct list of methods + methods = [ + true_labels, + logistic_regression + ] + + // construct list of metrics + metrics = [ + accuracy + ] + + /**************************** + * EXTRACT DATASET METADATA * + ****************************/ + dataset_ch = input_ch + // store join id + | map{ id, state -> + [id, state + ["_meta": [join_id: id]]] + } + + // extract the dataset metadata + | extract_metadata.run( + fromState: [input: "input_solution"], + toState: { id, output, state -> + state + [ + dataset_uns: readYaml(output.output).uns + ] + } + ) + + /*************************** + * RUN METHODS AND METRICS * + ***************************/ + score_ch = dataset_ch + + // run all methods + | runEach( + components: methods, + + // use the 'filter' argument to only run a method on the normalisation the component is asking for + filter: { id, state, comp -> + def norm = state.dataset_uns.normalization_id + def pref = comp.config.info.preferred_normalization + // if the preferred normalisation is none at all, + // we can pass whichever dataset we want + def norm_check = (norm == "log_cp10k" && pref == "counts") || norm == pref + def method_check = !state.method_ids || state.method_ids.contains(comp.config.name) + + method_check && norm_check + }, + + // define a new 'id' by appending the method name to the dataset id + id: { id, state, comp -> + id + "." + comp.config.name + }, + + // use 'fromState' to fetch the arguments the component requires from the overall state + fromState: { id, state, comp -> + def new_args = [ + input_train: state.input_train, + input_test: state.input_test + ] + if (comp.config.info.type == "control_method") { + new_args.input_solution = state.input_solution + } + new_args + }, + + // use 'toState' to publish that component's outputs to the overall state + toState: { id, output, state, comp -> + state + [ + method_id: comp.config.name, + method_output: output.output + ] + } + ) + + // run all metrics + | runEach( + components: metrics, + id: { id, state, comp -> + id + "." + comp.config.name + }, + // use 'fromState' to fetch the arguments the component requires from the overall state + fromState: [ + input_solution: "input_solution", + input_prediction: "method_output" + ], + // use 'toState' to publish that component's outputs to the overall state + toState: { id, output, state, comp -> + state + [ + metric_id: comp.config.name, + metric_output: output.output + ] + } + ) + + + /****************************** + * GENERATE OUTPUT YAML FILES * + ******************************/ + // TODO: can we store everything below in a separate helper function? + + // extract the dataset metadata + dataset_meta_ch = dataset_ch + // only keep one of the normalization methods + | filter{ id, state -> + state.dataset_uns.normalization_id == "log_cp10k" + } + | joinStates { ids, states -> + // store the dataset metadata in a file + def dataset_uns = states.collect{state -> + def uns = state.dataset_uns.clone() + uns.remove("normalization_id") + uns + } + def dataset_uns_yaml_blob = toYamlBlob(dataset_uns) + def dataset_uns_file = tempFile("dataset_uns.yaml") + dataset_uns_file.write(dataset_uns_yaml_blob) + + ["output", [output_dataset_info: dataset_uns_file]] + } + + output_ch = score_ch + + // extract the scores + | extract_metadata.run( + key: "extract_scores", + fromState: [input: "metric_output"], + toState: { id, output, state -> + state + [ + score_uns: readYaml(output.output).uns + ] + } + ) + + | joinStates { ids, states -> + // store the method configs in a file + def method_configs = methods.collect{it.config} + def method_configs_yaml_blob = toYamlBlob(method_configs) + def method_configs_file = tempFile("method_configs.yaml") + method_configs_file.write(method_configs_yaml_blob) + + // store the metric configs in a file + def metric_configs = metrics.collect{it.config} + def metric_configs_yaml_blob = toYamlBlob(metric_configs) + def metric_configs_file = tempFile("metric_configs.yaml") + metric_configs_file.write(metric_configs_yaml_blob) + + def task_info_file = meta.resources_dir.resolve("_viash.yaml") + + // store the scores in a file + def score_uns = states.collect{it.score_uns} + def score_uns_yaml_blob = toYamlBlob(score_uns) + def score_uns_file = tempFile("score_uns.yaml") + score_uns_file.write(score_uns_yaml_blob) + + def new_state = [ + output_method_configs: method_configs_file, + output_metric_configs: metric_configs_file, + output_task_info: task_info_file.info, + output_scores: score_uns_file, + _meta: states[0]._meta + ] + + ["output", new_state] + } + + // merge all of the output data + | mix(dataset_meta_ch) + | joinStates{ ids, states -> + def mergedStates = states.inject([:]) { acc, m -> acc + m } + [ids[0], mergedStates] + } + + emit: + output_ch +} + +// temp fix for rename_keys typo + +def findStatesTemp(Map params, Map config) { + def auto_config = deepClone(config) + def auto_params = deepClone(params) + + auto_config = auto_config.clone() + // override arguments + auto_config.argument_groups = [] + auto_config.arguments = [ + [ + type: "string", + name: "--id", + description: "A dummy identifier", + required: false + ], + [ + type: "file", + name: "--input_states", + example: "/path/to/input/directory/**/state.yaml", + description: "Path to input directory containing the datasets to be integrated.", + required: true, + multiple: true, + multiple_sep: ";" + ], + [ + type: "string", + name: "--filter", + example: "foo/.*/state.yaml", + description: "Regex to filter state files by path.", + required: false + ], + // to do: make this a yaml blob? + [ + type: "string", + name: "--rename_keys", + example: ["newKey1:oldKey1", "newKey2:oldKey2"], + description: "Rename keys in the detected input files. This is useful if the input files do not match the set of input arguments of the workflow.", + required: false, + multiple: true, + multiple_sep: ";" + ], + [ + type: "string", + name: "--settings", + example: '{"output_dataset": "dataset.h5ad", "k": 10}', + description: "Global arguments as a JSON glob to be passed to all components.", + required: false + ] + ] + if (!(auto_params.containsKey("id"))) { + auto_params["id"] = "auto" + } + + // run auto config through processConfig once more + auto_config = processConfig(auto_config) + + workflow findStatesTempWf { + helpMessage(auto_config) + + output_ch = + channelFromParams(auto_params, auto_config) + | flatMap { autoId, args -> + + def globalSettings = args.settings ? readYamlBlob(args.settings) : [:] + + // look for state files in input dir + def stateFiles = args.input_states + + // filter state files by regex + if (args.filter) { + stateFiles = stateFiles.findAll{ stateFile -> + def stateFileStr = stateFile.toString() + def matcher = stateFileStr =~ args.filter + matcher.matches()} + } + + // read in states + def states = stateFiles.collect { stateFile -> + def state_ = readTaggedYaml(stateFile) + [state_.id, state_] + } + + // construct renameMap + if (args.rename_keys) { + def renameMap = args.rename_keys.collectEntries{renameString -> + def split = renameString.split(":") + assert split.size() == 2: "Argument 'rename_keys' should be of the form 'newKey:oldKey;newKey:oldKey'" + split + } + + // rename keys in state, only let states through which have all keys + // also add global settings + states = states.collectMany{id, state -> + def newState = [:] + + for (key in renameMap.keySet()) { + def origKey = renameMap[key] + if (!(state.containsKey(origKey))) { + return [] + } + newState[key] = state[origKey] + } + + [[id, globalSettings + newState]] + } + } + + states + } + emit: + output_ch + } + + return findStatesTempWf +} \ No newline at end of file