add support for ascending CDF values with numeric output_type_id cast as character #105

Merged Jan 8, 2025 (29 commits; showing changes from 15 commits)
Commits
98c4252
add support for ascen cdf values w/num out_type_id
zkamvar Aug 12, 2024
3ed727d
dang lintr what I ever do to you?
zkamvar Aug 12, 2024
a7a4be9
Update tests/testthat/test-check_tbl_value_col_ascending.R
zkamvar Aug 12, 2024
8421214
fix ordering to be based on hub config
zkamvar Aug 13, 2024
808e6e6
add arguments in validate_model_data
zkamvar Aug 13, 2024
f551702
update, man
zkamvar Aug 13, 2024
dea1cc8
appease the linter
zkamvar Aug 13, 2024
d3d2937
fix skipped test
zkamvar Aug 13, 2024
defd517
remove :skull: code
zkamvar Aug 13, 2024
c82014d
avoid message in test
zkamvar Nov 7, 2024
cf92b6a
clean up internal doc
zkamvar Nov 7, 2024
c64bf73
appeas lintr
zkamvar Nov 7, 2024
8e69bb8
fix test assumptions; add failure test check
zkamvar Jan 3, 2025
f52f619
move early exit; only create cdf/quantile grid
zkamvar Jan 3, 2025
f3f04cf
oops
zkamvar Jan 3, 2025
d2517dd
add failing test for check of ascending columns
elray1 Jan 7, 2025
c0c356a
simplify order_output_type_ids; clarify object names
zkamvar Jan 7, 2025
95d47af
Update tests/testthat/test-check_tbl_value_col_ascending.R
elray1 Jan 7, 2025
c8b9e67
read in test data using read_model_out_file
zkamvar Jan 7, 2025
b76cf0f
Merge pull request #188 from hubverse-org/elr/ascending_check_test_case
zkamvar Jan 7, 2025
a3eb124
read in model output using hubvalidation mechanism
zkamvar Jan 7, 2025
3471e4d
add MWE for internal function
zkamvar Jan 7, 2025
7b43e64
update grouping for sorting; do not coerce ref to char
zkamvar Jan 7, 2025
3148f40
removed modification of target in test
zkamvar Jan 7, 2025
afa20db
add TODO note
zkamvar Jan 7, 2025
98b022e
appease linter
zkamvar Jan 7, 2025
93a5eb7
clean up expectations
zkamvar Jan 7, 2025
1591351
add internal doc
zkamvar Jan 7, 2025
ca6eea7
add NEWS item
zkamvar Jan 8, 2025
53 changes: 48 additions & 5 deletions R/check_tbl_value_col_ascending.R
@@ -11,8 +11,11 @@
#' @inherit check_tbl_colnames params
#' @inherit check_tbl_col_types return
#' @export
check_tbl_value_col_ascending <- function(tbl, file_path) {
if (all(!c("cdf", "quantile") %in% tbl[["output_type"]])) {
check_tbl_value_col_ascending <- function(tbl, file_path, hub_path, round_id) {

# Exit early if there are no values to check
no_values_to_check <- all(!c("cdf", "quantile") %in% tbl[["output_type"]])
if (no_values_to_check) {
return(
capture_check_info(
file_path,
@@ -22,8 +25,23 @@ check_tbl_value_col_ascending <- function(tbl, file_path) {
)
}

output_type_tbl <- split(tbl, tbl[["output_type"]])[c("cdf", "quantile")] %>%
purrr::compact()
# create a model output table subset to only the CDF and/or quantile values
# regardless of whether they are optional or required
config_tasks <- hubUtils::read_config(hub_path, "tasks")
round_output_types <- get_round_output_type_names(config_tasks, round_id)
only_cdf_or_quantile <- intersect(c("cdf", "quantile"), round_output_types)
accepted_vals <- expand_model_out_grid(
config_tasks = config_tasks,
round_id = round_id,
all_character = TRUE,
force_output_types = TRUE,
output_types = only_cdf_or_quantile

@annakrystalli (Member), Jan 8, 2025:

Here it would be much cleaner to use bind_model_tasks = FALSE. This returns a list of tibbles, one for each model task, keeping values associated with different model tasks separate. You can then map over any model tasks, processing them separately. This would cleanly get round the issue @elray1 described of different quantile levels in different model tasks.

Member Author:

TIL about this feature 😮

I'll give it a go!

Member:

Honestly, given the fact that this already works, I don't think it's strictly necessary to implement. Good to know of the functionality though as it's really useful for separating submission files into model tasks by joining them onto an expanded grid created with bind_model_tasks = FALSE!

)

# FIX for <https://github.com/hubverse-org/hubValidations/issues/78>
# sort the table by config by merging from config ----------------
tbl_sorted <- order_output_type_ids(tbl, accepted_vals, c("cdf", "quantile"))
output_type_tbl <- split_cdf_quantile(tbl_sorted)

error_tbl <- purrr::map(
output_type_tbl,
@@ -57,8 +75,8 @@ check_values_ascending <- function(tbl) {
group_cols <- names(tbl)[!names(tbl) %in% hubUtils::std_colnames]
tbl[["value"]] <- as.numeric(tbl[["value"]])

# group by all of the target columns
check_tbl <- dplyr::group_by(tbl, dplyr::across(dplyr::all_of(group_cols))) %>%
dplyr::arrange(.data$output_type_id, .by_group = TRUE) %>%
dplyr::summarise(non_asc = any(diff(.data[["value"]]) < 0))

if (!any(check_tbl$non_asc)) {
@@ -72,3 +90,28 @@ check_values_ascending <- function(tbl) {
dplyr::ungroup() %>%
dplyr::mutate(.env$output_type)
}
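
The per-group test in `check_values_ascending()` comes down to asking whether any consecutive difference in `value` is negative. A minimal base R sketch of that idea follows; the `toy` data frame and the `has_decrease()` helper are hypothetical illustrations, not part of the package, and the rows are assumed to be already ordered by `output_type_id` within each group.

```r
# Hypothetical toy data: two task groups, each with three quantile levels.
toy <- data.frame(
  location = rep(c("US", "01"), each = 3),
  output_type_id = rep(c(0.25, 0.5, 0.75), times = 2),
  value = c(10, 20, 30, 10, 5, 30)
)

# TRUE when any consecutive pair of values decreases.
has_decrease <- function(x) any(diff(x) < 0)

# Apply the test within each task group (here, grouped by location only).
non_asc <- vapply(
  split(toy$value, toy$location),
  has_decrease,
  logical(1)
)

non_asc[["US"]] # FALSE: 10, 20, 30 never decreases
non_asc[["01"]] # TRUE: 10 -> 5 is a decrease
```

Because `diff()` only compares neighbouring rows, the result is meaningful only if the rows are in the intended `output_type_id` order, which is why the sorting step matters.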

split_cdf_quantile <- function(tbl) {
split(tbl, tbl[["output_type"]])[c("cdf", "quantile")] %>%
purrr::compact()
}

# Order the output type ids in the order of the config
#
# This extracts the output_type_id from the config-generated table for the
# given types and creates a lookup table that has the types in the right order.
#
# The data from `tbl` is then joined into the lookup table (after being coerced
# to character), which sorts `tbl` in the order of the lookup table.
#
# NOTE: this assumes that the cdf and quantile values in the `tbl` are complete.
order_output_type_ids <- function(tbl, config, types = c("cdf", "quantile")) {
# step 1: create a lookup table from the config
order_ref <- config[c("output_type", "output_type_id")]
cdf_and_quantile <- order_ref$output_type %in% types
order_ref <- order_ref[cdf_and_quantile, , drop = FALSE]
order_ref <- unique(order_ref)
# step 2: join
tbl$output_type_id <- as.character(tbl$output_type_id)
dplyr::inner_join(order_ref, tbl, by = c("output_type", "output_type_id"))
}
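
To illustrate why ordering by the config matters once numeric `output_type_id` values are cast to character, here is a small base R sketch. It uses `match()` as a stand-in for the `dplyr::inner_join()` in `order_output_type_ids()`; the `config_ids` and `submitted` objects are hypothetical, and, like the function above, the sketch assumes the submitted IDs are complete.

```r
# Hypothetical config order of numeric IDs, cast to character.
config_ids <- as.character(c(0.5, 2, 10)) # "0.5" "2" "10"

# Hypothetical submission with rows in arbitrary order.
submitted <- data.frame(
  output_type_id = c("10", "0.5", "2"),
  value = c(0.9, 0.1, 0.4)
)

# Character ordering compares strings, so "10" sorts before "2".
# method = "radix" is used so the result does not depend on the locale.
by_char <- submitted[order(submitted$output_type_id, method = "radix"), ]
by_char$output_type_id # "0.5" "10" "2"
any(diff(by_char$value) < 0) # TRUE: a spurious decrease

# Ordering by position in the config lookup restores the numeric order.
by_config <- submitted[match(config_ids, submitted$output_type_id), ]
by_config$value # 0.1 0.4 0.9
any(diff(by_config$value) < 0) # FALSE
```

The same values pass or fail the ascending check depending only on row order, which is the failure mode the config-based lookup is meant to avoid.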
4 changes: 3 additions & 1 deletion R/validate_model_data.R
@@ -212,7 +212,9 @@ validate_model_data <- function(hub_path, file_path, round_id_col = NULL,
checks$value_col_non_desc <- try_check(
check_tbl_value_col_ascending(
tbl,
file_path = file_path
file_path = file_path,
hub_path = hub_path,
round_id = round_id
), file_path
)

14 changes: 13 additions & 1 deletion man/check_tbl_value_col_ascending.Rd

Generated file; diff not rendered.

16 changes: 9 additions & 7 deletions tests/testthat/_snaps/check_tbl_value_col_ascending.md
@@ -1,7 +1,7 @@
# check_tbl_value_col_ascending works

Code
check_tbl_value_col_ascending(tbl, file_path)
check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id)
Output
<message/check_success>
Message:
@@ -10,7 +10,7 @@
---

Code
check_tbl_value_col_ascending(tbl, file_path)
check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id)
Output
<message/check_success>
Message:
@@ -19,7 +19,7 @@
# check_tbl_value_col_ascending works when output type IDs not ordered

Code
check_tbl_value_col_ascending(tbl, file_path)
check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id)
Output
<message/check_success>
Message:
@@ -28,7 +28,7 @@
# check_tbl_value_col_ascending errors correctly

Code
str(check_tbl_value_col_ascending(tbl, file_path))
str(check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id))
Output
List of 7
$ message : chr "Values in `value` column are not non-decreasing as output_type_ids increase for all unique task ID\n value/o"| __truncated__
@@ -48,7 +48,8 @@
---

Code
str(check_tbl_value_col_ascending(tbl_error, file_path))
str(check_tbl_value_col_ascending(tbl_error, file_path, hub_path, file_meta$
round_id))
Output
List of 7
$ message : chr "Values in `value` column are not non-decreasing as output_type_ids increase for all unique task ID\n value/o"| __truncated__
@@ -68,7 +69,8 @@
---

Code
str(check_tbl_value_col_ascending(rbind(tbl, tbl_error), file_path))
str(check_tbl_value_col_ascending(rbind(tbl, tbl_error), file_path, hub_path,
file_meta$round_id))
Output
List of 7
$ message : chr "Values in `value` column are not non-decreasing as output_type_ids increase for all unique task ID\n value/o"| __truncated__
@@ -88,7 +90,7 @@
# check_tbl_value_col_ascending skips correctly

Code
check_tbl_value_col_ascending(tbl, file_path)
check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id)
Output
<message/check_info>
Message:
109 changes: 101 additions & 8 deletions tests/testthat/test-check_tbl_value_col_ascending.R
@@ -1,60 +1,69 @@
test_that("check_tbl_value_col_ascending works", {
hub_path <- system.file("testhubs/simple", package = "hubValidations")
file_path <- "team1-goodmodel/2022-10-08-team1-goodmodel.csv"
file_meta <- parse_file_name(file_path)
tbl <- hubValidations::read_model_out_file(file_path, hub_path)

expect_snapshot(
check_tbl_value_col_ascending(tbl, file_path)
check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id)
)

hub_path <- system.file("testhubs/flusight", package = "hubUtils")
file_path <- "hub-ensemble/2023-05-08-hub-ensemble.parquet"
file_meta <- parse_file_name(file_path)

tbl <- hubValidations::read_model_out_file(file_path, hub_path)

expect_snapshot(
check_tbl_value_col_ascending(tbl, file_path)
check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id)
)
})

test_that("check_tbl_value_col_ascending works when output type IDs not ordered", {
hub_path <- test_path("testdata/hub-unordered/")
tbl <- arrow::read_csv_arrow(
test_path("testdata/files/2024-01-10-ISI-NotOrdered.csv")
fs::path(hub_path, "model-output/2024-01-10-ISI-NotOrdered.csv")
) %>%
hubData::coerce_to_character()
file_path <- "ISI-NotOrdered/2024-01-10-ISI-NotOrdered.csv"
file_meta <- parse_file_name(file_path)
expect_snapshot(
check_tbl_value_col_ascending(tbl, file_path)
check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id)
)
})

test_that("check_tbl_value_col_ascending errors correctly", {
hub_path <- system.file("testhubs/simple", package = "hubValidations")
file_path <- "team1-goodmodel/2022-10-08-team1-goodmodel.csv"
file_meta <- parse_file_name(file_path)
tbl <- hubValidations::read_model_out_file(file_path, hub_path)

tbl$value[c(1, 10)] <- 150

expect_snapshot(
str(check_tbl_value_col_ascending(tbl, file_path))
str(check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id))
)

hub_path <- system.file("testhubs/flusight", package = "hubUtils")
file_path <- "hub-ensemble/2023-05-08-hub-ensemble.parquet"
file_meta <- parse_file_name(file_path)
tbl <- hubValidations::read_model_out_file(file_path, hub_path)
tbl_error <- tbl
tbl_error$target <- "wk ahead inc covid hosp"
tbl_error$value[1] <- 800

expect_snapshot(
str(
check_tbl_value_col_ascending(tbl_error, file_path)
check_tbl_value_col_ascending(tbl_error, file_path, hub_path, file_meta$round_id)
)
)
expect_snapshot(
str(
check_tbl_value_col_ascending(
rbind(tbl, tbl_error),
file_path
file_path,
hub_path,
file_meta$round_id
)
)
)
@@ -63,10 +72,94 @@ test_that("check_tbl_value_col_ascending errors correctly", {
test_that("check_tbl_value_col_ascending skips correctly", {
hub_path <- system.file("testhubs/simple", package = "hubValidations")
file_path <- "team1-goodmodel/2022-10-08-team1-goodmodel.csv"
file_meta <- parse_file_name(file_path)
tbl <- hubValidations::read_model_out_file(file_path, hub_path)
tbl <- tbl[tbl$output_type == "mean", ]

expect_snapshot(
check_tbl_value_col_ascending(tbl, file_path)
check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id)
)
})


test_that("(#78) check_tbl_value_col_ascending will sort even if the data doesn't naturally sort", {
# In this situation, I am duplicating the simple testhub and modifying it in
# one way:
#
# I am replacing the `quantile` model task with `cdf` and adding a cumulative
# sum so that we can get unsortable numbers.
make_unsortable <- function(x) suppressWarnings(x + 1:23)

# Duplicating the simple test hub ---------------------------------------
hub_path <- withr::local_tempdir()
fs::dir_copy(system.file("testhubs/simple", package = "hubValidations"),
hub_path,
overwrite = TRUE
)

# Creating the CFG output -----------------------------------------------
cfg <- hubUtils::read_config(hub_path, "tasks")
outputs <- cfg$rounds[[1]]$model_tasks[[1]]$output_type
outputs$cdf <- outputs$quantile
outputs$quantile <- NULL
otid <- outputs$cdf$output_type_id$required
# making the CDF range from 1.01 to 23.99 so that we can distinguish failures
# with character sorting.
outputs$cdf$output_type_id$required <- make_unsortable(otid)
cfg$rounds[[1]]$model_tasks[[1]]$output_type <- outputs
jsonlite::toJSON(cfg) %>%
jsonlite::prettify() %>%
writeLines(fs::path(hub_path, "hub-config", "tasks.json"))

# Updating the data to match the config --------------------------------
file_path <- "team1-goodmodel/2022-10-08-team1-goodmodel.csv"
file_meta <- parse_file_name(file_path)
convert_to_cdf <- function(x) {
ifelse(x == "quantile", "cdf", x)
}
tbl <- hubValidations::read_model_out_file(file_path, hub_path) %>%
dplyr::mutate(output_type_id = make_unsortable(.data[["output_type_id"]])) %>%
dplyr::mutate(output_type = convert_to_cdf(.data[["output_type"]]))

# validating when it is sorted -----------------------------------------
res <- check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id)
expect_s3_class(res, "check_success")
expect_null(res$error_tbl)

# validating when table rows are randomly ordered ----------------------
# In this check, the values still ascend with the output type ID, despite
# the rows being unordered.
res_unordered <- check_tbl_value_col_ascending(
tbl[sample(nrow(tbl)), ],
file_path,
hub_path,
file_meta$round_id
)
expect_s3_class(res_unordered, "check_success")
expect_null(res_unordered$error_tbl)

# mismatched values will result in an error ----------------------------
# if we switch the first two values, this will mean that they are no longer
# non-descending.
tbl_with_err <- tbl
tbl_with_err$value[1:2] <- tbl_with_err$value[2:1]
res_with_err <- check_tbl_value_col_ascending(
tbl_with_err,
file_path,
hub_path,
file_meta$round_id
)
expected <- tibble::tibble(
origin_date = as.Date("2022-10-08"),
target = "wk inc flu hosp",
horizon = 1,
location = "US",
output_type = "cdf"
)
actual <- res_with_err$error_tbl

expect_s3_class(res_with_err, "check_failure")
expect_s3_class(actual, "data.frame")
expect_equal(nrow(actual), 1)
expect_equal(actual, expected, ignore_attr = TRUE)
})