-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #398 from reichlab/as_covid_hub_forecasts
As covid hub forecasts function
- Loading branch information
Showing
4 changed files
with
384 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
#' Reformat model outputs to the COVID-19 Forecast Hub format
#'
#' Reformat model outputs stored as a `model_output_tbl` class (or similar) to
#' that of a `data.frame` formatted according to standards of the COVID-19
#' Forecasting Hub which can be processed by functions from the `covidHubUtils`
#' package such as `score_forecasts()` or `plot_forecasts()`. The supplied
#' `model_output_tbl` should have columns defining properties akin to
#' reference dates, locations, horizons, and targets.
#'
#' Every column named through a `*_col` argument is renamed to its hub
#' equivalent before any further processing. When a date column has to be
#' derived, the temporal resolutions "d"/"day", "w"/"wk"/"week",
#' "m"/"mth"/"mnth"/"month" and "y"/"yr"/"year" are recognised; for any other
#' value the derived date is a copy of the date column that was provided.
#'
#' @param model_outputs an object of class `model_output_tbl` with component
#'   model outputs (e.g., predictions). Should have columns containing the
#'   following information: model name, reference date or target end date,
#'   location, horizon, target, temporal resolution*, output type, output
#'   type id, and value. Note that the temporal resolution may be included
#'   in the target column.
#' @param model_id_col `character` string of the name of the column
#'   containing the model name(s) for the forecasts. Defaults to "model_id".
#'   Should be set to NULL if no such column exists, in which case a warning
#'   is issued and a model_id column is created, populated with the value
#'   "model_id".
#' @param reference_date_col `character` string of the name of the column
#'   containing the reference dates for the forecasts. Defaults to
#'   "forecast_date". Should be set to NULL if no such column exists, in which
#'   case the column will be created using the following information:
#'   horizon, target end date, and temporal resolution.
#' @param location_col `character` string of the name of the column
#'   containing the locations for the forecasts. Defaults to "location".
#' @param horizon_col `character` string of the name of the column
#'   containing the horizons for the forecasts. Defaults to "horizon".
#' @param target_col `character` string of the name of the column
#'   containing the targets for the forecasts. Defaults to "target". If
#'   `temp_res_col` is NULL, the target column in `model_outputs` is assumed
#'   to contain targets of the form "temporal_resolution target" or
#'   "temporal_resolution ahead target", such as "wk ahead inc flu hosp" or
#'   "wk inc flu hosp".
#' @param output_type_col `character` string of the name of the column
#'   containing the output types for the forecasts. Defaults to "output_type".
#' @param output_type_id_col `character` string of the name of the column
#'   containing the output type ids for the forecasts. Defaults to
#'   "output_type_id".
#' @param value_col `character` string of the name of the column
#'   containing the values for the forecasts. Defaults to "value".
#' @param temp_res_col `character` string of the name of the column
#'   containing the temporal resolutions for the forecasts. Defaults to
#'   "temporal_resolution". Should be set to NULL if no such column exists,
#'   in which case the column will be created from the existing target column.
#' @param target_end_date_col `character` string of the name of the column
#'   containing the target end dates for the forecasts. Defaults to
#'   "target_end_date". Should be set to NULL if no such column exists, in
#'   which case the column will be created using the following information:
#'   horizon, forecast date, and temporal resolution.
#'
#' @return a `data.frame` of reformatted model outputs that may be fed into
#'   any of the `covidHubUtils` functions with 10 total columns: model,
#'   forecast_date, location, horizon, temporal_resolution, target_variable,
#'   target_end_date, type, quantile, value. Other columns are removed. The
#'   horizon column is returned as `character`, and output types "mean" and
#'   "median" are relabelled as "point".
#' @export
#'
#' @examples
#' library(dplyr)
#' forecasts <- load_forecasts(
#'   models = c("COVIDhub-ensemble", "UMass-MechBayes"),
#'   dates = "2020-12-14",
#'   date_window_size = 7,
#'   locations = c("US"),
#'   targets = paste(1:4, "wk ahead inc death"),
#'   source = "zoltar"
#' )
#' altered_forecasts <- forecasts |> # Alter forecasts to not be CovidHub format
#'   dplyr::rename(model_id=model, output_type=type, output_type_id=quantile) |>
#'   dplyr::mutate(target_variable = "wk ahead inc death", horizon=as.numeric(horizon)) |>
#'   dplyr::select(-temporal_resolution)
#' formatted_forecasts <- as_covid_hub_forecasts(
#'   altered_forecasts,
#'   target_col="target_variable",
#'   temp_res_col=NULL
#' ) |>
#'   dplyr::mutate(horizon=as.character(horizon))
#' testthat::expect_equal(formatted_forecasts, dplyr::select(forecasts, model:value))

as_covid_hub_forecasts <- function(model_outputs, model_id_col = "model_id",
                                   reference_date_col = "forecast_date",
                                   location_col = "location",
                                   horizon_col = "horizon",
                                   target_col = "target",
                                   output_type_col = "output_type",
                                   output_type_id_col = "output_type_id",
                                   value_col = "value",
                                   temp_res_col = "temporal_resolution",
                                   target_end_date_col = "target_end_date") {
  # c() drops NULL entries, so only the names that were actually supplied are
  # checked against the columns of `model_outputs`.
  provided_names <- c(
    model_id_col, reference_date_col, location_col, horizon_col, target_col,
    output_type_col, output_type_id_col, value_col, temp_res_col,
    target_end_date_col
  )
  mandatory_cols <- list(
    location_col, horizon_col, target_col,
    output_type_col, output_type_id_col, value_col
  )

  if (any(purrr::map_lgl(mandatory_cols, is.null))) {
    stop(
      "You must provide the names of columns with location, horizon, target, ",
      "output_type, output_type_id, and value information."
    )
  }

  if (!all(provided_names %in% names(model_outputs))) {
    stop("Not all provided column names exist in the provided model_outputs.")
  }

  if (is.null(reference_date_col) && is.null(target_end_date_col)) {
    stop("You must provide the name of at least one date column.")
  }

  # Both kinds of point forecast are relabelled "point" at the end, so having
  # the two together would make them indistinguishable.
  if (all(c("mean", "median") %in% model_outputs[[output_type_col]])) {
    stop("You may only have one type of point forecast.")
  }

  if (is.null(model_id_col)) {
    warning("No model_id_col provided, creating one automatically.")
    model_id_col <- "model_id"
    model_outputs <- dplyr::mutate(
      model_outputs,
      model_id = "model_id",
      .before = 1
    )
  }

  # Map every supplied column onto its hub name (new = old). NULL arguments
  # vanish from the vector, so columns that still need to be derived are simply
  # not renamed. Renaming horizon, temporal resolution and target end date
  # here is what lets the code below refer to the fixed hub names.
  hub_names <- c(
    model = model_id_col,
    forecast_date = reference_date_col,
    location = location_col,
    horizon = horizon_col,
    target_variable = target_col,
    temporal_resolution = temp_res_col,
    target_end_date = target_end_date_col,
    type = output_type_col,
    quantile = output_type_id_col,
    value = value_col
  )

  model_outputs <- model_outputs |>
    dplyr::rename(dplyr::all_of(hub_names)) |>
    # Horizons must be numeric for the date arithmetic below.
    dplyr::mutate(horizon = as.numeric(horizon))

  if (is.null(temp_res_col)) {
    # Split "wk ahead inc death" / "wk inc death" into the temporal resolution
    # (first word) and the target variable (everything after it).
    model_outputs <- model_outputs |>
      dplyr::rename(target = target_variable) |>
      dplyr::mutate(target = ifelse(
        stringr::str_detect(target, "ahead"),
        stringr::str_squish(stringr::str_replace(target, "ahead", "")),
        target
      )) |>
      tidyr::separate(
        target,
        sep = " ",
        convert = TRUE,
        into = c("temporal_resolution", "target_variable"),
        extra = "merge"
      )
  }

  if (is.null(reference_date_col)) {
    # Step back `horizon` units of the temporal resolution from the target end
    # date; unrecognised resolutions leave the date unchanged.
    model_outputs <- model_outputs |>
      dplyr::mutate(forecast_date = dplyr::case_when(
        temporal_resolution %in% c("d", "day") ~
          target_end_date - lubridate::days(horizon),
        temporal_resolution %in% c("w", "wk", "week") ~
          target_end_date - lubridate::weeks(horizon),
        temporal_resolution %in% c("m", "mth", "mnth", "month") ~
          lubridate::`%m-%`(target_end_date, months(horizon)),
        temporal_resolution %in% c("y", "yr", "year") ~
          target_end_date - lubridate::years(horizon),
        .default = target_end_date
      ))
  }

  if (is.null(target_end_date_col)) {
    # Step forward `horizon` units of the temporal resolution from the
    # forecast date; unrecognised resolutions leave the date unchanged.
    model_outputs <- model_outputs |>
      dplyr::mutate(target_end_date = dplyr::case_when(
        temporal_resolution %in% c("d", "day") ~
          forecast_date + lubridate::days(horizon),
        temporal_resolution %in% c("w", "wk", "week") ~
          forecast_date + lubridate::weeks(horizon),
        temporal_resolution %in% c("m", "mth", "mnth", "month") ~
          lubridate::`%m+%`(forecast_date, months(horizon)),
        temporal_resolution %in% c("y", "yr", "year") ~
          forecast_date + lubridate::years(horizon),
        .default = forecast_date
      ))
  }

  # Final hub layout: fixed column order, character horizons, and a single
  # "point" label for mean/median forecasts. All other columns are dropped.
  model_outputs |>
    dplyr::mutate(
      horizon = as.character(horizon),
      type = ifelse(type %in% c("mean", "median"), "point", type)
    ) |>
    dplyr::select(
      model, forecast_date, location, horizon, temporal_resolution,
      target_variable, target_end_date, type, quantile, value
    )
}
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
library(dplyr) | ||
library(lubridate) | ||
library(stringr) | ||
|
||
# Minimal two-row fixture shared by the tests below: two models giving a
# one-week-ahead point forecast of incident deaths for the same location.
simple_test_forecasts <- data.frame(
  model_id = paste0("source", 1:2),
  forecast_date = rep(ymd(20200101), 2),
  location = rep("01", 2),
  horizon = rep("1", 2),
  temporal_resolution = rep("wk", 2),
  target_variable = rep("inc death", 2),
  target_end_date = rep(ymd(20200108), 2),
  output_type = rep("point", 2),
  output_type_id = rep(NA, 2),
  value = c(3, 4),
  stringsAsFactors = FALSE
)
|
||
# A mandatory column name set to NULL must be rejected. The argument is spelled
# out in full (no partial matching) and the message is pinned so that an
# unrelated error cannot satisfy the expectation.
test_that("Not providing the names of all mandatory columns throws an error", {
  expect_error(
    as_covid_hub_forecasts(
      simple_test_forecasts,
      horizon_col = NULL,
      target_col = "target_variable",
      temp_res_col = "temporal_resolution",
      target_end_date_col = "target_end_date"
    ),
    regexp = "You must provide the names of columns"
  )
})
|
||
# Naming a column that is absent from the data must be rejected; the message is
# pinned so that only the intended validation error satisfies the expectation.
test_that("Providing a column name not in the model_outputs throws an error", {
  expect_error(
    as_covid_hub_forecasts(
      simple_test_forecasts,
      location_col = "unit",
      target_col = "target_variable",
      temp_res_col = "temporal_resolution",
      target_end_date_col = "target_end_date"
    ),
    regexp = "Not all provided column names exist"
  )
})
|
||
# With both date columns set to NULL there is nothing to derive dates from;
# the message is pinned so that only that validation error counts.
test_that("Not providing any date columns throws an error", {
  expect_error(
    as_covid_hub_forecasts(
      simple_test_forecasts,
      target_col = "target_variable",
      temp_res_col = "temporal_resolution",
      reference_date_col = NULL,
      target_end_date_col = NULL
    ),
    regexp = "at least one date column"
  )
})
|
||
# Mixing "mean" and "median" output types must be rejected; the message is
# pinned so that only that validation error counts.
test_that("Inclusion of multiple types of point forecasts throws an error", {
  # test_that() evaluates in its own environment, so this modifies a local
  # copy of the shared fixture only.
  simple_test_forecasts$output_type <- c("median", "mean")
  expect_error(
    as_covid_hub_forecasts(
      simple_test_forecasts,
      target_col = "target_variable",
      temp_res_col = "temporal_resolution",
      target_end_date_col = "target_end_date"
    ),
    regexp = "only have one type of point forecast"
  )
})
|
||
# Passing model_id_col = NULL should warn that a model column is being created.
# The argument is spelled out in full (no partial matching) and the warning
# text is pinned so that an unrelated warning cannot satisfy the expectation.
test_that("Not including a model_id column generates a warning", {
  without_model_id <- dplyr::select(simple_test_forecasts, -model_id)
  expect_warning(
    as_covid_hub_forecasts(
      without_model_id,
      model_id_col = NULL,
      target_col = "target_variable",
      temp_res_col = "temporal_resolution",
      target_end_date_col = "target_end_date"
    ),
    regexp = "No model_id_col provided"
  )
})
|
||
# When no temporal resolution column is supplied, the target strings (with or
# without the word "ahead") are split into resolution and target variable.
test_that("The resulting temporal resolution and target variable columns are correctly formatted when a temporal resolution column is not provided", {
  split_forecasts <- simple_test_forecasts |>
    dplyr::mutate(
      target_variable = c("wk inc death", "wk ahead inc death")
    ) |>
    dplyr::select(-temporal_resolution) |>
    as_covid_hub_forecasts(
      target_col = "target_variable",
      temp_res_col = NULL,
      target_end_date_col = "target_end_date"
    )

  expect_equal(split_forecasts$temporal_resolution, rep("wk", 2))
  expect_equal(split_forecasts$target_variable, rep("inc death", 2))
})
|
||
# With the forecast date column removed, it should be rebuilt as the target end
# date minus one week. `reference_date_col` is spelled out in full rather than
# relying on partial argument matching.
test_that("Reference dates are correctly calculated if not provided", {
  expected_reference_date <- c(ymd(20200101), ymd(20200101))
  derived <- simple_test_forecasts |>
    dplyr::select(-forecast_date) |>
    as_covid_hub_forecasts(
      target_col = "target_variable",
      temp_res_col = "temporal_resolution",
      reference_date_col = NULL,
      target_end_date_col = "target_end_date"
    )
  expect_equal(derived$forecast_date, expected_reference_date)
})
|
||
# With the target end date column removed, it should be rebuilt as the forecast
# date plus one week. `reference_date_col` is spelled out in full rather than
# relying on partial argument matching.
test_that("Target end dates are correctly calculated if not provided", {
  expected_target_end_date <- c(ymd(20200108), ymd(20200108))
  derived <- simple_test_forecasts |>
    dplyr::select(-target_end_date) |>
    as_covid_hub_forecasts(
      target_col = "target_variable",
      temp_res_col = "temporal_resolution",
      reference_date_col = "forecast_date",
      target_end_date_col = NULL
    )
  expect_equal(derived$target_end_date, expected_target_end_date)
})
|
||
# Columns that are not part of the hub format must be dropped from the output.
test_that("Only columns required by the Covid Hub are kept", {
  expected_names <- sort(c(
    "model", "forecast_date", "location", "horizon", "target_variable",
    "type", "quantile", "value", "temporal_resolution", "target_end_date"
  ))
  with_extra_cols <- dplyr::mutate(
    simple_test_forecasts,
    abbreviation = "AL",
    full_location_name = "Alabama"
  )
  reformatted <- as_covid_hub_forecasts(
    with_extra_cols,
    target_col = "target_variable",
    temp_res_col = "temporal_resolution",
    target_end_date_col = "target_end_date"
  )
  expect_equal(sort(names(reformatted)), expected_names)
})