diff --git a/R/as_covid_hub_forecasts.R b/R/as_covid_hub_forecasts.R index 6c58339..55566e0 100644 --- a/R/as_covid_hub_forecasts.R +++ b/R/as_covid_hub_forecasts.R @@ -54,6 +54,7 @@ #' @export #' #' @examples +#' library(dplyr) #' forecasts <- load_forecasts( #' models = c("COVIDhub-ensemble", "UMass-MechBayes"), #' dates = "2020-12-14", @@ -63,16 +64,16 @@ #' source = "zoltar" #' ) #' altered_forecasts <- forecasts |> # Alter forecasts to not be CovidHub format -# ' rename(model_id=model, output_type=type, output_type_id=quantile) |> -#' mutate(target_variable = "wk ahead inc death", horizon=as.numeric(horizon)) |> -#' select(-temporal_resolution) +#' dplyr::rename(model_id=model, output_type=type, output_type_id=quantile) |> +#' dplyr::mutate(target_variable = "wk ahead inc death", horizon=as.numeric(horizon)) |> +#' dplyr::select(-temporal_resolution) #' formatted_forecasts <- as_covid_hub_forecasts( #' altered_forecasts, #' target_col="target_variable", #' temp_res_col=NULL #' ) |> -#' mutate(horizon=as.character(horizon)) -#' expect_equal(formatted_forecasts, dplyr::select(forecasts, model:value)) +#' dplyr::mutate(horizon=as.character(horizon)) +#' testthat::expect_equal(formatted_forecasts, dplyr::select(forecasts, model:value)) as_covid_hub_forecasts <- function(model_outputs, model_id_col = "model_id", reference_date_col="forecast_date", @@ -87,7 +88,7 @@ as_covid_hub_forecasts <- function(model_outputs, model_id_col = "model_id", provided_names <- c(model_id_col, reference_date_col, location_col, horizon_col, target_col, output_type_col, output_type_id_col, value_col, temp_res_col, target_end_date_col) mandatory_cols <- list(location_col, horizon_col, target_col, output_type_col, output_type_id_col, value_col) - if (any(map_lgl(mandatory_cols, is.null))) { + if (any(purrr::map_lgl(mandatory_cols, is.null))) { stop("You must provide the names of columns with location, horizon, target, output_type, output_type_id, and value information.") } @@ -119,7 +120,7 @@ as_covid_hub_forecasts <- function(model_outputs, model_id_col = "model_id", if (is.null(temp_res_col)) { model_outputs <- model_outputs |> dplyr::rename(target = target_variable) |> - mutate(target = ifelse( + dplyr::mutate(target = ifelse( stringr::str_detect(target, "ahead"), stringr::str_replace(target, "ahead", "") |> stringr::str_squish(), target)) |> @@ -128,7 +129,7 @@ as_covid_hub_forecasts <- function(model_outputs, model_id_col = "model_id", if (is.null(reference_date_col)) { model_outputs <- model_outputs |> - dplyr::mutate(forecast_date=case_when( + dplyr::mutate(forecast_date=dplyr::case_when( temporal_resolution %in% c("d", "day") ~ target_end_date - lubridate::days(horizon), temporal_resolution %in% c("w", "wk", "week") ~ target_end_date - lubridate::weeks(horizon), temporal_resolution %in% c("m", "mth", "mnth", "month") ~ target_end_date %m-% months(horizon), @@ -139,7 +140,7 @@ as_covid_hub_forecasts <- function(model_outputs, model_id_col = "model_id", if (is.null(target_end_date_col)) { model_outputs <- model_outputs |> - dplyr::mutate(target_end_date=case_when( + dplyr::mutate(target_end_date=dplyr::case_when( temporal_resolution %in% c("d", "day") ~ forecast_date + lubridate::days(horizon), temporal_resolution %in% c("w", "wk", "week") ~ forecast_date + lubridate::weeks(horizon), temporal_resolution %in% c("m", "mth", "mnth", "month") ~ forecast_date %m+% months(horizon), diff --git a/man/as_covid_hub_forecasts.Rd b/man/as_covid_hub_forecasts.Rd index 7c72852..c5c638c 100644 --- a/man/as_covid_hub_forecasts.Rd +++ b/man/as_covid_hub_forecasts.Rd @@ -91,6 +91,7 @@ package such as \code{score_forecasts()} or \code{plot_forecasts()}. The supplie reference dates, locations, horizons, and targets. } \examples{ +library(dplyr) forecasts <- load_forecasts( models = c("COVIDhub-ensemble", "UMass-MechBayes"), dates = "2020-12-14", @@ -100,13 +101,14 @@ forecasts <- load_forecasts( source = "zoltar" ) altered_forecasts <- forecasts |> # Alter forecasts to not be CovidHub format - mutate(target_variable = "wk ahead inc death", horizon=as.numeric(horizon)) |> - select(-temporal_resolution) + dplyr::rename(model_id=model, output_type=type, output_type_id=quantile) |> + dplyr::mutate(target_variable = "wk ahead inc death", horizon=as.numeric(horizon)) |> + dplyr::select(-temporal_resolution) formatted_forecasts <- as_covid_hub_forecasts( altered_forecasts, target_col="target_variable", temp_res_col=NULL ) |> -mutate(horizon=as.character(horizon)) -expect_equal(formatted_forecasts, dplyr::select(forecasts, model:value)) +dplyr::mutate(horizon=as.character(horizon)) +testthat::expect_equal(formatted_forecasts, dplyr::select(forecasts, model:value)) } diff --git a/tests/testthat/test-as_covid_hub_forecasts.R b/tests/testthat/test-as_covid_hub_forecasts.R index bd07218..f2be59e 100644 --- a/tests/testthat/test-as_covid_hub_forecasts.R +++ b/tests/testthat/test-as_covid_hub_forecasts.R @@ -17,16 +17,6 @@ simple_test_forecasts <- data.frame( stringsAsFactors = FALSE ) - as_covid_hub_forecasts(model_outputs, model_id_col = "model_id", - reference_date_col="forecast_date", - location_col="location", - horizon_col="horizon", target_col="target", - output_type_col="output_type", - output_type_id_col="output_type_id", - value_col="value", - temp_res_col="temporal_resolution", - target_end_date_col="target_end_date") - test_that("Not providing the names of all mandatory columns throws an error", { simple_test_forecasts |> as_covid_hub_forecasts(horizon=NULL, target_col="target_variable", @@ -83,20 +73,6 @@ test_that("The resulting temporal resolution and target variable columns are cor expect_equal(test_forecasts_hub$temporal_resolution, actual_temporal_resolution) expect_equal(test_forecasts_hub$target_variable, actual_target_variable) }) - - test_forecasts <- data.frame( - model_id = c("source1", "source2"), - forecast_date = c(ymd(20200101), ymd(20200101)), - location = c("01", "01"), - horizon = c("1", "1"), - temporal_resolution = c("wk", "wk"), - target_variable = c("inc death", "inc death"), - target_end_date = c(ymd(20200108), ymd(20200108)), - output_type = c("point", "point"), - output_type_id = c(NA, NA), - value = c(3, 4), - stringsAsFactors = FALSE - ) test_that("Reference dates are correctly calculated if not provided", { actual_reference_date = c(ymd(20200101), ymd(20200101)) @@ -106,7 +82,7 @@ test_that("Reference dates are correctly calculated if not provided", { temp_res_col="temporal_resolution", reference_date=NULL, target_end_date_col="target_end_date") |> - pull(forecast_date) |> + dplyr::pull(forecast_date) |> expect_equal(actual_reference_date) }) @@ -118,7 +94,7 @@ test_that("Target end dates are correctly calculated if not provided", { temp_res_col="temporal_resolution", reference_date="forecast_date", target_end_date_col=NULL) |> - pull(target_end_date) |> + dplyr::pull(target_end_date) |> expect_equal(actual_target_end_date) })