-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmatch_tbl_to_model_task.R
49 lines (48 loc) · 1.73 KB
/
match_tbl_to_model_task.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#' Match model output `tbl` data to their model tasks in `config_tasks`.
#'
#' Split and match model output `tbl` data to their corresponding model tasks in
#' `config_tasks`. Useful for performing model task specific checks on model
#' output. For v3 samples, the `output_type_id` column is set to `NA` for
#' `sample` outputs.
#'
#' Matching is performed with an inner join on every column of `tbl` except
#' `value`. `all_character` is forwarded to [expand_model_out_grid()], so when
#' it is set to `FALSE` the column types of `tbl` need to be compatible with
#' those of the expanded grid for the join to succeed.
#' @inheritParams expand_model_out_grid
#' @inheritParams check_tbl_colnames
#'
#' @return A list containing a `tbl_df` of model output data matched to a model
#' task with one element per round model task. Elements corresponding to model
#' tasks whose expanded grid has no rows are `NULL`.
#' @export
#'
#' @examples
#' hub_path <- system.file("testhubs/samples", package = "hubValidations")
#' tbl <- read_model_out_file(
#'   file_path = "flu-base/2022-10-22-flu-base.csv",
#'   hub_path, coerce_types = "chr"
#' )
#' config_tasks <- hubUtils::read_config(hub_path, "tasks")
#' match_tbl_to_model_task(tbl, config_tasks, round_id = "2022-10-22")
#' match_tbl_to_model_task(tbl, config_tasks,
#'   round_id = "2022-10-22",
#'   output_types = "sample"
#' )
match_tbl_to_model_task <- function(tbl, config_tasks, round_id,
                                    output_types = NULL,
                                    derived_task_ids = NULL,
                                    all_character = TRUE) {
  # Join on every column except `value`. The v3 adjustment below only changes
  # cell contents, never column names, so this remains valid afterwards.
  join_cols <- names(tbl)[names(tbl) != "value"]

  if (hubUtils::is_v3_config(config_tasks)) {
    # `%in%` never returns `NA`, so rows with a missing `output_type` are left
    # untouched rather than triggering an NA-subscript error in the assignment.
    is_sample <- tbl$output_type %in% "sample"
    # NOTE(review): presumably the expanded grid carries `NA` output type IDs
    # for v3 samples, which is why they are blanked here before joining.
    tbl[is_sample, "output_type_id"] <- NA
  }

  # One grid per model task (not bound together) so that each element of the
  # result corresponds to a single model task of the round.
  model_task_grids <- expand_model_out_grid(
    config_tasks,
    round_id = round_id,
    required_vals_only = FALSE,
    all_character = all_character,
    as_arrow_table = FALSE,
    bind_model_tasks = FALSE,
    output_types = output_types,
    derived_task_ids = derived_task_ids
  )

  purrr::map(model_task_grids, \(.x) {
    # An empty grid cannot match any rows; signal this with `NULL`.
    if (nrow(.x) == 0L) {
      return(NULL)
    }
    dplyr::inner_join(.x, tbl, by = join_cols)
  })
}