Skip to content

Commit

Permalink
Merge pull request #398 from reichlab/as_covid_hub_forecasts
Browse files Browse the repository at this point in the history
As covid hub forecasts function
  • Loading branch information
lshandross authored Nov 3, 2023
2 parents 7258bc1 + 953c1cd commit 8bed140
Show file tree
Hide file tree
Showing 4 changed files with 384 additions and 0 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

export(aggregate_to_weekly)
export(align_forecasts)
export(as_covid_hub_forecasts)
export(calc_cramers_dist_equal_space)
export(calc_cramers_dist_one_model_pair)
export(calc_cramers_dist_unequal_space)
Expand Down
159 changes: 159 additions & 0 deletions R/as_covid_hub_forecasts.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
#' Reformat model outputs to the COVID-19 Forecasting Hub format
#'
#' Reformat model outputs stored as a `model_output_tbl` (or similar) into a
#' `data.frame` formatted according to the standards of the COVID-19
#' Forecasting Hub, which can be processed by functions from the
#' `covidHubUtils` package such as `score_forecasts()` or `plot_forecasts()`.
#' The supplied `model_output_tbl` should have columns defining properties akin
#' to reference dates, locations, horizons, and targets.
#'
#' @param model_outputs an object of class `model_output_tbl` with component
#' model outputs (e.g., predictions). Should have columns containing the
#' following information: model name, reference date or target end date,
#' location, horizon, target, temporal resolution, output type, output
#' type id, and value. Note that the temporal resolution may instead be
#' included in the target column (see `target_col` and `temp_res_col`).
#' @param model_id_col `character` string of the name of the column
#' containing the model name(s) for the forecasts. Defaults to "model_id".
#' Should be set to NULL if no such column exists, in which case a model_id
#' column will be created populated with the value "model_id".
#' @param location_col `character` string of the name of the column
#' containing the locations for the forecasts. Defaults to "location".
#' @param horizon_col `character` string of the name of the column
#' containing the horizons for the forecasts. Defaults to "horizon".
#' @param target_col `character` string of the name of the column
#' containing the targets for the forecasts. Defaults to "target". If
#' `temp_res_col` is NULL, the target column in `model_outputs` is assumed
#' to contain targets of the form "temporal_resolution target" or
#' "temporal_resolution ahead target", such as "wk inc flu hosp" or
#' "wk ahead inc flu hosp".
#' @param reference_date_col `character` string of the name of the column
#' containing the reference dates for the forecasts. Defaults to
#' "forecast_date". Should be set to NULL if no such column exists, in which
#' case the column will be created using the following information:
#' horizon, target end date, and temporal resolution.
#' @param target_end_date_col `character` string of the name of the column
#' containing the target end dates for the forecasts. Defaults to
#' "target_end_date". Should be set to NULL if no such column exists, in
#' which case the column will be created using the following information:
#' horizon, forecast date, and temporal resolution.
#' @param output_type_col `character` string of the name of the column
#' containing the output types for the forecasts. Defaults to "output_type".
#' @param output_type_id_col `character` string of the name of the column
#' containing the output type ids for the forecasts. Defaults to
#' "output_type_id".
#' @param value_col `character` string of the name of the column
#' containing the values for the forecasts. Defaults to "value".
#' @param temp_res_col `character` string of the name of the column
#' containing the temporal resolutions for the forecasts. Defaults to
#' "temporal_resolution". Should be set to NULL if no such column exists,
#' in which case the column will be created from the existing target column.
#'
#' @return a `data.frame` of reformatted model outputs that may be fed into
#' any of the `covidHubUtils` functions with 10 total columns: model,
#' forecast_date, location, horizon, temporal_resolution, target_variable,
#' target_end_date, type, quantile, value. Other columns are removed.
#' @export
#'
#' @examples
#' library(dplyr)
#' forecasts <- load_forecasts(
#' models = c("COVIDhub-ensemble", "UMass-MechBayes"),
#' dates = "2020-12-14",
#' date_window_size = 7,
#' locations = c("US"),
#' targets = paste(1:4, "wk ahead inc death"),
#' source = "zoltar"
#' )
#' altered_forecasts <- forecasts |> # Alter forecasts to not be CovidHub format
#' dplyr::rename(model_id=model, output_type=type, output_type_id=quantile) |>
#' dplyr::mutate(target_variable = "wk ahead inc death", horizon=as.numeric(horizon)) |>
#' dplyr::select(-temporal_resolution)
#' formatted_forecasts <- as_covid_hub_forecasts(
#' altered_forecasts,
#' target_col="target_variable",
#' temp_res_col=NULL
#' ) |>
#' dplyr::mutate(horizon=as.character(horizon))
#' testthat::expect_equal(formatted_forecasts, dplyr::select(forecasts, model:value))

as_covid_hub_forecasts <- function(model_outputs, model_id_col = "model_id",
                                   reference_date_col = "forecast_date",
                                   location_col = "location",
                                   horizon_col = "horizon", target_col = "target",
                                   output_type_col = "output_type",
                                   output_type_id_col = "output_type_id",
                                   value_col = "value",
                                   temp_res_col = "temporal_resolution",
                                   target_end_date_col = "target_end_date") {
  # Converts `model_outputs` to the 10-column COVID-19 Forecasting Hub layout:
  # model, forecast_date, location, horizon, temporal_resolution,
  # target_variable, target_end_date, type, quantile, value.
  # Every `*_col` argument names the input column holding that piece of
  # information; optional ones may be NULL, in which case the column is derived.

  # c() drops NULL entries, so only the names actually supplied are checked.
  provided_names <- c(
    model_id_col, reference_date_col, location_col, horizon_col, target_col,
    output_type_col, output_type_id_col, value_col, temp_res_col,
    target_end_date_col
  )
  mandatory_cols <- list(
    location_col, horizon_col, target_col, output_type_col,
    output_type_id_col, value_col
  )

  if (any(purrr::map_lgl(mandatory_cols, is.null))) {
    stop("You must provide the names of columns with location, horizon, target, output_type, output_type_id, and value information.")
  }

  if (!all(provided_names %in% names(model_outputs))) {
    stop("Not all provided column names exist in the provided model_outputs.")
  }

  # At least one date column is needed; the other is derived from it below.
  if (is.null(reference_date_col) && is.null(target_end_date_col)) {
    stop("You must provide the name of at least one date column.")
  }

  # Both "mean" and "median" are mapped to type "point" at the end, so having
  # both would make the point forecasts indistinguishable.
  if (all(c("mean", "median") %in% model_outputs[[output_type_col]])) {
    stop("You may only have one type of point forecast.")
  }

  if (is.null(model_id_col)) {
    warning("No model_id_col provided, creating one automatically.")
    model_id_col <- "model_id"
    model_outputs <- dplyr::mutate(
      model_outputs,
      model_id = "model_id",
      .before = 1
    )
  }

  # Map every user-supplied column name to its hub name. Optional columns that
  # were passed as NULL vanish from this vector and are created further down.
  hub_names <- c(
    model = model_id_col,
    forecast_date = reference_date_col,
    location = location_col,
    horizon = horizon_col,
    temporal_resolution = temp_res_col,
    target_variable = target_col,
    target_end_date = target_end_date_col,
    type = output_type_col,
    quantile = output_type_id_col,
    value = value_col
  )

  model_outputs <- model_outputs |>
    dplyr::rename(dplyr::all_of(hub_names)) |>
    dplyr::mutate(horizon = as.numeric(horizon))

  # Without a temporal resolution column, the target is assumed to look like
  # "<resolution> [ahead] <target>": drop "ahead" and split on the first space.
  if (is.null(temp_res_col)) {
    model_outputs <- model_outputs |>
      dplyr::rename(target = target_variable) |>
      dplyr::mutate(target = ifelse(
        stringr::str_detect(target, "ahead"),
        stringr::str_squish(stringr::str_replace(target, "ahead", "")),
        target
      )) |>
      tidyr::separate(
        target,
        sep = " ",
        convert = TRUE,
        into = c("temporal_resolution", "target_variable"),
        extra = "merge"
      )
  }

  # Shift `dates` by `steps` units of each row's temporal resolution (negative
  # `steps` move backwards). Month shifts roll back to the last valid day of
  # the month; rows with an unrecognised resolution are returned unchanged.
  shift_dates <- function(dates, steps, resolution) {
    dplyr::case_when(
      resolution %in% c("d", "day") ~ dates + lubridate::days(steps),
      resolution %in% c("w", "wk", "week") ~ dates + lubridate::weeks(steps),
      resolution %in% c("m", "mth", "mnth", "month") ~
        lubridate::add_with_rollback(dates, lubridate::period(month = steps)),
      resolution %in% c("y", "yr", "year") ~ dates + lubridate::years(steps),
      .default = dates
    )
  }

  if (is.null(reference_date_col)) {
    model_outputs <- dplyr::mutate(
      model_outputs,
      forecast_date = shift_dates(target_end_date, -horizon, temporal_resolution)
    )
  }

  if (is.null(target_end_date_col)) {
    model_outputs <- dplyr::mutate(
      model_outputs,
      target_end_date = shift_dates(forecast_date, horizon, temporal_resolution)
    )
  }

  # Horizons are returned as character and point forecasts share one type
  # label; select() also drops every non-hub column.
  model_outputs |>
    dplyr::mutate(
      horizon = as.character(horizon),
      type = ifelse(type %in% c("mean", "median"), "point", type)
    ) |>
    dplyr::select(
      model, forecast_date, location, horizon, temporal_resolution,
      target_variable, target_end_date, type, quantile, value
    )
}
114 changes: 114 additions & 0 deletions man/as_covid_hub_forecasts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

110 changes: 110 additions & 0 deletions tests/testthat/test-as_covid_hub_forecasts.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
library(dplyr)
library(lubridate)
library(stringr)

# Shared fixture: two rows, one point forecast from each of two models, both
# for location "01" at a 1 "wk" horizon (2020-01-01 to 2020-01-08).
simple_test_forecasts <- data.frame(
  model_id = paste0("source", 1:2),
  forecast_date = rep(ymd(20200101), 2),
  location = rep("01", 2),
  horizon = rep("1", 2),
  temporal_resolution = rep("wk", 2),
  target_variable = rep("inc death", 2),
  target_end_date = rep(ymd(20200108), 2),
  output_type = rep("point", 2),
  output_type_id = rep(NA, 2),
  value = c(3, 4),
  stringsAsFactors = FALSE
)

test_that("Not providing the names of all mandatory columns throws an error", {
  # The argument is named in full so the test does not depend on partial
  # argument matching, and the message is pinned so that an unrelated failure
  # cannot satisfy the expectation.
  expect_error(
    as_covid_hub_forecasts(
      simple_test_forecasts,
      horizon_col = NULL,
      target_col = "target_variable",
      temp_res_col = "temporal_resolution",
      target_end_date_col = "target_end_date"
    ),
    regexp = "You must provide the names of columns"
  )
})

test_that("Providing a column name not in the model_outputs throws an error", {
  # "unit" is not a column of the fixture, so the call must fail.
  expect_error(
    as_covid_hub_forecasts(
      simple_test_forecasts,
      location_col = "unit",
      target_col = "target_variable",
      temp_res_col = "temporal_resolution",
      target_end_date_col = "target_end_date"
    )
  )
})

test_that("Not providing any date columns throws an error", {
  # With both date column names set to NULL there is nothing to derive dates
  # from, so the call must fail.
  expect_error(
    as_covid_hub_forecasts(
      simple_test_forecasts,
      target_col = "target_variable",
      temp_res_col = "temporal_resolution",
      reference_date_col = NULL,
      target_end_date_col = NULL
    )
  )
})

test_that("Inclusion of multiple types of point forecasts throws an error", {
  # Local copy of the fixture carrying both a median and a mean forecast.
  mixed_point_forecasts <- simple_test_forecasts
  mixed_point_forecasts[["output_type"]] <- c("median", "mean")

  expect_error(
    as_covid_hub_forecasts(
      mixed_point_forecasts,
      target_col = "target_variable",
      temp_res_col = "temporal_resolution",
      target_end_date_col = "target_end_date"
    )
  )
})

test_that("Not including a model_id column generates a warning", {
  forecasts_without_model <- dplyr::select(simple_test_forecasts, -model_id)

  # The argument is named in full so the test does not depend on partial
  # argument matching, and the message is pinned so that only the intended
  # warning satisfies the expectation.
  expect_warning(
    as_covid_hub_forecasts(
      forecasts_without_model,
      model_id_col = NULL,
      target_col = "target_variable",
      temp_res_col = "temporal_resolution",
      target_end_date_col = "target_end_date"
    ),
    regexp = "No model_id_col provided"
  )
})

test_that("The resulting temporal resolution and target variable columns are correctly formatted when a temporal resolution column is not provided", {
  expected_temporal_resolution <- rep("wk", 2)
  expected_target_variable <- rep("inc death", 2)

  # Local copy whose targets embed the resolution, with and without "ahead",
  # and which has no separate temporal resolution column.
  combined_targets <- simple_test_forecasts
  combined_targets[["target_variable"]] <- c("wk inc death", "wk ahead inc death")
  combined_targets <- dplyr::select(combined_targets, -temporal_resolution)

  hub_forecasts <- as_covid_hub_forecasts(
    combined_targets,
    target_col = "target_variable",
    temp_res_col = NULL,
    target_end_date_col = "target_end_date"
  )

  expect_equal(hub_forecasts$temporal_resolution, expected_temporal_resolution)
  expect_equal(hub_forecasts$target_variable, expected_target_variable)
})

test_that("Reference dates are correctly calculated if not provided", {
  expected_reference_date <- c(ymd(20200101), ymd(20200101))

  # `reference_date_col` is named in full so the test does not depend on
  # partial argument matching.
  hub_forecasts <- simple_test_forecasts |>
    dplyr::select(-forecast_date) |>
    as_covid_hub_forecasts(
      target_col = "target_variable",
      temp_res_col = "temporal_resolution",
      reference_date_col = NULL,
      target_end_date_col = "target_end_date"
    )

  expect_equal(hub_forecasts$forecast_date, expected_reference_date)
})

test_that("Target end dates are correctly calculated if not provided", {
  expected_target_end_date <- c(ymd(20200108), ymd(20200108))

  # `reference_date_col` is named in full so the test does not depend on
  # partial argument matching.
  hub_forecasts <- simple_test_forecasts |>
    dplyr::select(-target_end_date) |>
    as_covid_hub_forecasts(
      target_col = "target_variable",
      temp_res_col = "temporal_resolution",
      reference_date_col = "forecast_date",
      target_end_date_col = NULL
    )

  expect_equal(hub_forecasts$target_end_date, expected_target_end_date)
})

test_that("Only columns required by the Covid Hub are kept", {
  expected_cols <- sort(c(
    "model", "forecast_date", "location", "horizon", "target_variable",
    "type", "quantile", "value", "temporal_resolution", "target_end_date"
  ))

  # Add two extra columns that are not part of the hub format.
  padded_forecasts <- dplyr::mutate(
    simple_test_forecasts,
    abbreviation = "AL",
    full_location_name = "Alabama"
  )

  kept_cols <- names(
    as_covid_hub_forecasts(
      padded_forecasts,
      target_col = "target_variable",
      temp_res_col = "temporal_resolution",
      target_end_date_col = "target_end_date"
    )
  )

  expect_equal(sort(kept_cols), expected_cols)
})

0 comments on commit 8bed140

Please sign in to comment.