Skip to content

Commit

Permalink
fix as_covid_hub_forecasts check errors
Browse files Browse the repository at this point in the history
  • Loading branch information
lshandross committed Nov 1, 2023
1 parent 1978f72 commit a0c967e
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 39 deletions.
19 changes: 10 additions & 9 deletions R/as_covid_hub_forecasts.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
#' @export
#'
#' @examples
#' library(dplyr)
#' forecasts <- load_forecasts(
#' models = c("COVIDhub-ensemble", "UMass-MechBayes"),
#' dates = "2020-12-14",
Expand All @@ -63,16 +64,16 @@
#' source = "zoltar"
#' )
#' altered_forecasts <- forecasts |> # Alter forecasts to not be CovidHub format
# ' rename(model_id=model, output_type=type, output_type_id=quantile) |>
#' mutate(target_variable = "wk ahead inc death", horizon=as.numeric(horizon)) |>
#' select(-temporal_resolution)
#' dplyr::rename(model_id=model, output_type=type, output_type_id=quantile) |>
#' dplyr::mutate(target_variable = "wk ahead inc death", horizon=as.numeric(horizon)) |>
#' dplyr::select(-temporal_resolution)
#' formatted_forecasts <- as_covid_hub_forecasts(
#' altered_forecasts,
#' target_col="target_variable",
#' temp_res_col=NULL
#' ) |>
#' mutate(horizon=as.character(horizon))
#' expect_equal(formatted_forecasts, dplyr::select(forecasts, model:value))
#' dplyr::mutate(horizon=as.character(horizon))
#' testthat::expect_equal(formatted_forecasts, dplyr::select(forecasts, model:value))

as_covid_hub_forecasts <- function(model_outputs, model_id_col = "model_id",
reference_date_col="forecast_date",
Expand All @@ -87,7 +88,7 @@ as_covid_hub_forecasts <- function(model_outputs, model_id_col = "model_id",
provided_names <- c(model_id_col, reference_date_col, location_col, horizon_col, target_col, output_type_col, output_type_id_col, value_col, temp_res_col, target_end_date_col)
mandatory_cols <- list(location_col, horizon_col, target_col, output_type_col, output_type_id_col, value_col)

if (any(map_lgl(mandatory_cols, is.null))) {
if (any(purrr::map_lgl(mandatory_cols, is.null))) {
stop("You must provide the names of columns with location, horizon, target, output_type, output_type_id, and value information.")
}

Expand Down Expand Up @@ -119,7 +120,7 @@ as_covid_hub_forecasts <- function(model_outputs, model_id_col = "model_id",
if (is.null(temp_res_col)) {
model_outputs <- model_outputs |>
dplyr::rename(target = target_variable) |>
mutate(target = ifelse(
dplyr::mutate(target = ifelse(
stringr::str_detect(target, "ahead"),
stringr::str_replace(target, "ahead", "") |> stringr::str_squish(),
target)) |>
Expand All @@ -128,7 +129,7 @@ as_covid_hub_forecasts <- function(model_outputs, model_id_col = "model_id",

if (is.null(reference_date_col)) {
model_outputs <- model_outputs |>
dplyr::mutate(forecast_date=case_when(
dplyr::mutate(forecast_date=dplyr::case_when(
temporal_resolution %in% c("d", "day") ~ target_end_date - lubridate::days(horizon),
temporal_resolution %in% c("w", "wk", "week") ~ target_end_date - lubridate::weeks(horizon),
temporal_resolution %in% c("m", "mth", "mnth", "month") ~ target_end_date %m-% months(horizon),
Expand All @@ -139,7 +140,7 @@ as_covid_hub_forecasts <- function(model_outputs, model_id_col = "model_id",

if (is.null(target_end_date_col)) {
model_outputs <- model_outputs |>
dplyr::mutate(target_end_date=case_when(
dplyr::mutate(target_end_date=dplyr::case_when(
temporal_resolution %in% c("d", "day") ~ forecast_date + lubridate::days(horizon),
temporal_resolution %in% c("w", "wk", "week") ~ forecast_date + lubridate::weeks(horizon),
temporal_resolution %in% c("m", "mth", "mnth", "month") ~ forecast_date %m+% months(horizon),
Expand Down
10 changes: 6 additions & 4 deletions man/as_covid_hub_forecasts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 2 additions & 26 deletions tests/testthat/test-as_covid_hub_forecasts.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,6 @@ simple_test_forecasts <- data.frame(
stringsAsFactors = FALSE
)

as_covid_hub_forecasts(model_outputs, model_id_col = "model_id",
reference_date_col="forecast_date",
location_col="location",
horizon_col="horizon", target_col="target",
output_type_col="output_type",
output_type_id_col="output_type_id",
value_col="value",
temp_res_col="temporal_resolution",
target_end_date_col="target_end_date")

test_that("Not providing the names of all mandatory columns throws an error", {
simple_test_forecasts |>
as_covid_hub_forecasts(horizon=NULL, target_col="target_variable",
Expand Down Expand Up @@ -83,20 +73,6 @@ test_that("The resulting temporal resolution and target variable columns are cor
expect_equal(test_forecasts_hub$temporal_resolution, actual_temporal_resolution)
expect_equal(test_forecasts_hub$target_variable, actual_target_variable)
})

test_forecasts <- data.frame(
model_id = c("source1", "source2"),
forecast_date = c(ymd(20200101), ymd(20200101)),
location = c("01", "01"),
horizon = c("1", "1"),
temporal_resolution = c("wk", "wk"),
target_variable = c("inc death", "inc death"),
target_end_date = c(ymd(20200108), ymd(20200108)),
output_type = c("point", "point"),
output_type_id = c(NA, NA),
value = c(3, 4),
stringsAsFactors = FALSE
)

test_that("Reference dates are correctly calculated if not provided", {
actual_reference_date = c(ymd(20200101), ymd(20200101))
Expand All @@ -106,7 +82,7 @@ test_that("Reference dates are correctly calculated if not provided", {
temp_res_col="temporal_resolution",
reference_date=NULL,
target_end_date_col="target_end_date") |>
pull(forecast_date) |>
dplyr::pull(forecast_date) |>
expect_equal(actual_reference_date)
})

Expand All @@ -118,7 +94,7 @@ test_that("Target end dates are correctly calculated if not provided", {
temp_res_col="temporal_resolution",
reference_date="forecast_date",
target_end_date_col=NULL) |>
pull(target_end_date) |>
dplyr::pull(target_end_date) |>
expect_equal(actual_target_end_date)
})

Expand Down

0 comments on commit a0c967e

Please sign in to comment.