add run benchmark workflow [WIP]
KaiWaldrant committed Jul 11, 2024
1 parent 06d1da3 commit 6510489
Showing 1 changed file with 308 additions and 0 deletions.
308 changes: 308 additions & 0 deletions src/workflows/run_benchmark/main.nf
@@ -0,0 +1,308 @@
workflow auto {
  findStatesTemp(params, meta.config)
    | meta.workflow.run(
      auto: [publish: "state"]
    )
}
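
// The 'auto' entrypoint discovers state.yaml files with findStatesTemp and feeds
// each detected state through the benchmark workflow, publishing the final states.
// Hypothetical invocation of the built workflow (all paths are placeholders):
//   nextflow run . -main-script target/nextflow/workflows/run_benchmark/main.nf \
//     -entry auto -profile docker \
//     --input_states "resources/**/state.yaml" --publish_dir "output"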

workflow run_wf {
  take:
  input_ch

  main:

  // construct list of methods
  methods = [
    true_labels,
    logistic_regression
  ]

  // construct list of metrics
  metrics = [
    accuracy
  ]
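  // note: the method and metric components (true_labels, logistic_regression,
  // accuracy) are assumed to be made available via the workflow's viash dependencies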

  /****************************
   * EXTRACT DATASET METADATA *
   ****************************/
  dataset_ch = input_ch
    // store join id
    | map{ id, state ->
      [id, state + ["_meta": [join_id: id]]]
    }

    // extract the dataset metadata
    | extract_metadata.run(
      fromState: [input: "input_solution"],
      toState: { id, output, state ->
        state + [
          dataset_uns: readYaml(output.output).uns
        ]
      }
    )
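  // each state now carries `dataset_uns`: the solution file's uns metadata
  // (e.g. dataset_id, normalization_id)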

  /***************************
   * RUN METHODS AND METRICS *
   ***************************/
  score_ch = dataset_ch

    // run all methods
    | runEach(
      components: methods,

      // use the 'filter' argument to only run a method on the normalisation it asks for
      filter: { id, state, comp ->
        def norm = state.dataset_uns.normalization_id
        def pref = comp.config.info.preferred_normalization
        // a component that prefers raw counts can run on any normalisation
        // variant, so only pass it the log_cp10k variant to run it once
        def norm_check = (norm == "log_cp10k" && pref == "counts") || norm == pref
        def method_check = !state.method_ids || state.method_ids.contains(comp.config.name)

        method_check && norm_check
      },

      // define a new 'id' by appending the method name to the dataset id
      id: { id, state, comp ->
        id + "." + comp.config.name
      },

      // use 'fromState' to fetch the arguments the component requires from the overall state
      fromState: { id, state, comp ->
        def new_args = [
          input_train: state.input_train,
          input_test: state.input_test
        ]
        if (comp.config.info.type == "control_method") {
          new_args.input_solution = state.input_solution
        }
        new_args
      },

      // use 'toState' to publish that component's outputs to the overall state
      toState: { id, output, state, comp ->
        state + [
          method_id: comp.config.name,
          method_output: output.output
        ]
      }
    )
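    // ids now look like "<dataset_id>.<method_id>" and each state carries the
    // method's prediction file in `method_output`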

    // run all metrics
    | runEach(
      components: metrics,
      id: { id, state, comp ->
        id + "." + comp.config.name
      },
      // use 'fromState' to fetch the arguments the component requires from the overall state
      fromState: [
        input_solution: "input_solution",
        input_prediction: "method_output"
      ],
      // use 'toState' to publish that component's outputs to the overall state
      toState: { id, output, state, comp ->
        state + [
          metric_id: comp.config.name,
          metric_output: output.output
        ]
      }
    )
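    // ids now look like "<dataset_id>.<method_id>.<metric_id>", with each score
    // file stored in `metric_output`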


  /******************************
   * GENERATE OUTPUT YAML FILES *
   ******************************/
  // TODO: can we store everything below in a separate helper function?

  // extract the dataset metadata
  dataset_meta_ch = dataset_ch
    // only keep one of the normalization methods
    | filter{ id, state ->
      state.dataset_uns.normalization_id == "log_cp10k"
    }
    | joinStates { ids, states ->
      // store the dataset metadata in a file
      def dataset_uns = states.collect{state ->
        def uns = state.dataset_uns.clone()
        uns.remove("normalization_id")
        uns
      }
      def dataset_uns_yaml_blob = toYamlBlob(dataset_uns)
      def dataset_uns_file = tempFile("dataset_uns.yaml")
      dataset_uns_file.write(dataset_uns_yaml_blob)

      ["output", [output_dataset_info: dataset_uns_file]]
    }
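  // dataset_meta_ch emits a single ["output", ...] tuple holding the dataset info file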

  output_ch = score_ch

    // extract the scores
    | extract_metadata.run(
      key: "extract_scores",
      fromState: [input: "metric_output"],
      toState: { id, output, state ->
        state + [
          score_uns: readYaml(output.output).uns
        ]
      }
    )

    | joinStates { ids, states ->
      // store the method configs in a file
      def method_configs = methods.collect{it.config}
      def method_configs_yaml_blob = toYamlBlob(method_configs)
      def method_configs_file = tempFile("method_configs.yaml")
      method_configs_file.write(method_configs_yaml_blob)

      // store the metric configs in a file
      def metric_configs = metrics.collect{it.config}
      def metric_configs_yaml_blob = toYamlBlob(metric_configs)
      def metric_configs_file = tempFile("metric_configs.yaml")
      metric_configs_file.write(metric_configs_yaml_blob)

      def task_info_file = meta.resources_dir.resolve("_viash.yaml")

      // store the scores in a file
      def score_uns = states.collect{it.score_uns}
      def score_uns_yaml_blob = toYamlBlob(score_uns)
      def score_uns_file = tempFile("score_uns.yaml")
      score_uns_file.write(score_uns_yaml_blob)

      def new_state = [
        output_method_configs: method_configs_file,
        output_metric_configs: metric_configs_file,
        output_task_info: task_info_file,
        output_scores: score_uns_file,
        _meta: states[0]._meta
      ]

      ["output", new_state]
    }
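    // note: this branch and dataset_meta_ch both emit under the id "output", so
    // the mix + joinStates below collapse them into a single merged state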

    // merge all of the output data
    | mix(dataset_meta_ch)
    | joinStates{ ids, states ->
      def mergedStates = states.inject([:]) { acc, m -> acc + m }
      [ids[0], mergedStates]
    }

  emit:
  output_ch
}

// temporary local copy of the findStates helper, included to fix a typo in its
// rename_keys handling

def findStatesTemp(Map params, Map config) {
  def auto_config = deepClone(config)
  def auto_params = deepClone(params)

  auto_config = auto_config.clone()
  // override arguments
  auto_config.argument_groups = []
  auto_config.arguments = [
    [
      type: "string",
      name: "--id",
      description: "A dummy identifier",
      required: false
    ],
    [
      type: "file",
      name: "--input_states",
      example: "/path/to/input/directory/**/state.yaml",
      description: "Path to input directory containing the datasets to be integrated.",
      required: true,
      multiple: true,
      multiple_sep: ";"
    ],
    [
      type: "string",
      name: "--filter",
      example: "foo/.*/state.yaml",
      description: "Regex to filter state files by path.",
      required: false
    ],
    // to do: make this a yaml blob?
    [
      type: "string",
      name: "--rename_keys",
      example: ["newKey1:oldKey1", "newKey2:oldKey2"],
      description: "Rename keys in the detected input files. This is useful if the input files do not match the set of input arguments of the workflow.",
      required: false,
      multiple: true,
      multiple_sep: ";"
    ],
    [
      type: "string",
      name: "--settings",
      example: '{"output_dataset": "dataset.h5ad", "k": 10}',
      description: "Global arguments as a JSON blob to be passed to all components.",
      required: false
    ]
  ]
  if (!(auto_params.containsKey("id"))) {
    auto_params["id"] = "auto"
  }

  // run auto config through processConfig once more
  auto_config = processConfig(auto_config)

  workflow findStatesTempWf {
    helpMessage(auto_config)

    output_ch =
      channelFromParams(auto_params, auto_config)
        | flatMap { autoId, args ->

          def globalSettings = args.settings ? readYamlBlob(args.settings) : [:]

          // look for state files in input dir
          def stateFiles = args.input_states

          // filter state files by regex
          if (args.filter) {
            stateFiles = stateFiles.findAll{ stateFile ->
              def stateFileStr = stateFile.toString()
              def matcher = stateFileStr =~ args.filter
              matcher.matches()
            }
          }

          // read in states
          def states = stateFiles.collect { stateFile ->
            def state_ = readTaggedYaml(stateFile)
            [state_.id, state_]
          }
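          // (readTaggedYaml is assumed to be the viash helper that parses a
          // state yaml and resolves its `!file` tags into actual paths)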

          // construct renameMap
          if (args.rename_keys) {
            def renameMap = args.rename_keys.collectEntries{renameString ->
              def split = renameString.split(":")
              assert split.size() == 2: "Argument 'rename_keys' should be of the form 'newKey:oldKey;newKey:oldKey'"
              split
            }
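            // e.g. --rename_keys "input_train:output_train;input_test:output_test"
            // yields [input_train: "output_train", input_test: "output_test"]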

            // rename keys in each state and drop states missing any requested key;
            // global settings are also merged in here
            states = states.collectMany{id, state ->
              def newState = [:]

              for (key in renameMap.keySet()) {
                def origKey = renameMap[key]
                if (!(state.containsKey(origKey))) {
                  return []
                }
                newState[key] = state[origKey]
              }

              [[id, globalSettings + newState]]
            }
          }

          states
        }

    emit:
    output_ch
  }

  return findStatesTempWf
}
