Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added forecast model wrapper #73

Merged
merged 11 commits into from
May 27, 2020
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export(evaluate_model)
export(fable_model)
export(forecastHybrid_model)
export(forecast_cases)
export(forecast_model)
export(forecast_rt)
export(iterative_case_forecast)
export(iterative_rt_forecast)
Expand Down Expand Up @@ -74,6 +75,7 @@ importFrom(stats,rpois)
importFrom(stats,rt)
importFrom(stats,sd)
importFrom(stats,setNames)
importFrom(stats,ts)
importFrom(tibble,tibble)
importFrom(tidyr,expand_grid)
importFrom(tidyr,gather)
Expand Down
85 changes: 83 additions & 2 deletions R/model-wrappers.R
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ fable_model <- function(y = NULL, samples = NULL,
return(samples)
}

#' forecastHybrid model wrapper
#' ForecastHybrid model wrapper
nikosbosse marked this conversation as resolved.
Show resolved Hide resolved
#'
#' Allows users to forecast using ensembles from the `forecastHybrid` package. Note that
#' whilst weighted ensembles can be created this is not advised when samples > 1 as currently
Expand Down Expand Up @@ -293,6 +293,88 @@ forecastHybrid_model <- function(y = NULL, samples = NULL,



#' Forecast model wrapper
#'
#' Allows users to forecast using models from the `forecast` package.
#' Note that `forecast` must be installed for this model wrapper to be functional.
#' @param model A `forecast` model object.
#' @inheritParams bsts_model
#' @param ... pass further arguments to the forecast models
#' @export
#' @return A dataframe of predictions (with columns representing the
#' time horizon and rows representing samples).
#'
#' @importFrom stats ts
#' @importFrom purrr map
#' @importFrom dplyr bind_rows
#'
#' @examples \dontrun{
#'
#' ## Used on its own
#' forecast_model(y = EpiSoon::example_obs_rts[1:10, ]$rt,
nikosbosse marked this conversation as resolved.
Show resolved Hide resolved
#' model = forecast::auto.arima,
#' samples = 10, horizon = 7)
#'
#'
#' forecast_rt(EpiSoon::example_obs_rts[1:10, ],
#' model = function(...){
#' forecast_model(model = forecast::ets, ...)},
#' horizon = 7, samples = 10)
#'
#'
#'
#' models <- list("ARIMA" = function(...) {forecast_model(model = forecast::auto.arima, ...)},
#' "ETS" = function(...) {forecast_model(model = forecast::ets, ...)},
#' "TBATS" = function(...) {forecast_model(model = forecast::tbats, ...)})
#'
#' ## Compare models
#' evaluations <- compare_models(EpiSoon::example_obs_rts,
#' EpiSoon::example_obs_cases, models,
#' horizon = 7, samples = 10,
#' serial_interval = example_serial_interval)
#'
#' plot_forecast_evaluation(evaluations$forecast_rts,
#' EpiSoon::example_obs_rts,
#' horizon_to_plot = 7) +
#' ggplot2::facet_grid(~ model) +
#' cowplot::panel_border()
#'}
#'

forecast_model <- function(y = NULL, samples = NULL,
nikosbosse marked this conversation as resolved.
Show resolved Hide resolved
horizon = NULL, model = NULL,
...) {

check_suggests("forecast")

# convert to timeseries object
timeseries <- stats::ts(y)

# fit and forecast
fit <- model(timeseries)
prediction <- forecast::forecast(fit, h = horizon)

## Extract samples and tidy format
sample_from_model <- prediction

if (samples == 1) {
sample_from_model <- data.frame(t(as.data.frame(sample_from_model$mean)))
rownames(sample_from_model) <- NULL
}else{
mean <- as.numeric(prediction$mean)
upper <- prediction$upper[, ncol(prediction$upper)]
lower <- prediction$lower[, ncol(prediction$lower)]
sd <- (upper - lower) / 3.92
sample_from_model <- purrr::map2(mean, sd,
nikosbosse marked this conversation as resolved.
Show resolved Hide resolved
~ rnorm(samples, mean = .x, sd = .y))

sample_from_model <- dplyr::bind_cols(sample_from_model)
}

return(sample_from_model)
}


#' Stack models according to CRPS
#'
#' @description
Expand Down Expand Up @@ -470,4 +552,3 @@ stackr_model <- function(y = NULL,
}



2 changes: 1 addition & 1 deletion man/forecastHybrid_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

61 changes: 61 additions & 0 deletions man/forecast_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.