From d17f1b1bb187274fb5b021601cd5d44a8d7b8039 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Thu, 1 Sep 2022 16:27:31 +0100 Subject: [PATCH] implement process options --- R/estimate_infections.R | 10 +++-- R/opts.R | 93 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 96 insertions(+), 7 deletions(-) diff --git a/R/estimate_infections.R b/R/estimate_infections.R index d2c4b7533..b095e9fad 100644 --- a/R/estimate_infections.R +++ b/R/estimate_infections.R @@ -19,6 +19,10 @@ #' #' @param reported_cases A data frame of confirmed cases (confirm) by date #' (date). confirm must be integer and date must be in date format. +#' @param process_model A character string that defines what is being +#' modelled: "infections", "growth" or "R" (default). If set to "R", +#' a generation time distribution needs to be defined via the `generation_time` +#' argument. #' @param generation_time A call to `generation_time_opts()` defining the #' generation time distribution used. For backwards compatibility a list of #' summary parameters can also be passed. 
@@ -209,7 +213,7 @@ #' options(old_opts) #' } estimate_infections <- function(reported_cases, - model = "R", + process_opts = process_opts(), generation_time = generation_time_opts(), delays = delay_opts(), truncation = trunc_opts(), @@ -280,6 +284,7 @@ estimate_infections <- function(reported_cases, # Define stan model parameters data <- create_stan_data( reported_cases = reported_cases, + process_opts = process_opts, generation_time = generation_time, delays = delays, truncation = truncation, @@ -288,8 +293,7 @@ estimate_infections <- function(reported_cases, obs = obs, backcalc = backcalc, shifted_cases = shifted_cases$confirm, - horizon = horizon, - process_model = process_model + horizon = horizon ) # Set up default settings diff --git a/R/opts.R b/R/opts.R index 2b342e567..5ee7cecca 100644 --- a/R/opts.R +++ b/R/opts.R @@ -67,7 +67,6 @@ generation_time_opts <- function(..., disease, source, max = 15L, fixed = FALSE, names(gt) <- paste0("gt_", names(gt)) return(gt) -} #' Delay Distribution Options #' @@ -187,7 +186,7 @@ trunc_opts <- function(...) { #' Time-Varying Reproduction Number Options #' -#' @description `r lifecycle::badge("stable")` +#' @description `r lifecycle::badge("deprecated")` #' Defines a list specifying the optional arguments for the time-varying reproduction number. #' Custom settings can be supplied which override the defaults. #' @param prior List containing named numeric elements "mean" and "sd". The mean and @@ -230,6 +229,7 @@ rt_opts <- function(prior = list(mean = 1, sd = 1), future = "latest", gp_on = "R_t-1", pop = 0) { + stop("rt_opts is deprecated - use process_opts instead") rt <- list( prior = prior, use_rt = use_rt, @@ -251,9 +251,93 @@ rt_opts <- function(prior = list(mean = 1, sd = 1), return(rt) } -#' Back Calculation Options +#' Process model options #' #' @description `r lifecycle::badge("stable")` +#' Defines a list specifying the optional arguments for the process model. 
+#' Custom settings can be supplied which override the defaults. +#' @param prior List containing named numeric elements "mean" and "sd". The mean and +#' standard deviation of the log normal Rt prior. Defaults to mean of 1 and standard +#' deviation of 1. +#' @param use_rt Logical, defaults to `TRUE`. Should Rt be used to generate infections +#' and hence reported cases. +#' @param rw Numeric step size of the random walk, defaults to 0. To specify a weekly random +#' walk set `rw = 7`. For more custom break point settings consider passing in a `breakpoints` +#' variable as outlined in the next section. +#' @param use_breakpoints Logical, defaults to `TRUE`. Should break points be used if present +#' as a `breakpoint` variable in the input data. Break points should be defined as 1 if present +#' and otherwise 0. By default breakpoints are fit jointly with a global non-parametric effect +#' and so represent a conservative estimate of break point changes (alter this by setting `gp = NULL`). +#' @param pop Integer, defaults to 0. Susceptible population initially present. Used to adjust +#' Rt estimates when otherwise fixed based on the proportion of the population that is +#' susceptible. When set to 0 no population adjustment is done. +#' @param gp_on Character string, defaulting to "R_t-1". Indicates how the Gaussian process, +#' if in use, should be applied to Rt. Currently supported options are applying the Gaussian +#' process to the last estimated Rt (i.e. Rt = Rt-1 * GP), and applying the Gaussian process to +#' a global mean (i.e. Rt = R0 * GP). Both should produce comparable results when data is not +#' sparse but the method relying on a global mean will revert to this for real time estimates, +#' which may not be desirable. 
+#' @return A list of settings defining the time-varying reproduction number +#' @inheritParams create_future_rt +#' @export +#' @examples +#' # default settings +#' rt_opts() +#' +#' # add a custom length scale +#' rt_opts(prior = list(mean = 2, sd = 1)) +#' +#' # add a weekly random walk +#' rt_opts(rw = 7) +#' @importFrom data.table fcase +process_opts <- function(model = "R", + prior_mean = data.table::fcase( + model == "R", list(mean = 1, sd = 1), + model == "growth", list(mean = 0, sd = 1), + model == "infections", NULL + ), + prior_t = NULL, + rw = 0, + use_breakpoints = TRUE, + future = "latest", + stationary = FALSE + pop = 0) { + + ## check + model_choices <- c("infections", "growth", "R") + process_model <- match.arg(process_model, choices = model_choices) + process_model <- which(process_model == model_choices) - 1 + + if (!(xor(is.null(prior_mean), is.null(prior_t)))) { + stop("Either 'prior_mean' or 'prior_t' must be set to NULL") + } + process <- list( + process_model = process_model, + prior_mean = prior_mean, + prior_t = prior_t, + rw = rw, + use_breakpoints = use_breakpoints, + future = future, + stationary = stationary, + pop = pop, + ) + + # replace default settings with those specified by user + if (process$rw > 0) { + process$use_breakpoints <- TRUE + } + + if (!is.null(prior_mean) && + !("mean" %in% names(process$prior) && + "sd" %in% names(process$prior))) { + stop("prior must have both a mean and sd specified") + } + return(process) +} + +#' Back Calculation Options +#' +#' @description `r lifecycle::badge("deprecated")` #' Defines a list specifying the optional arguments for the back calculation #' of cases. Only used if `rt = NULL`. #' @param prior A character string defaulting to "reports". 
Defines the prior to use @@ -276,7 +360,8 @@ rt_opts <- function(prior = list(mean = 1, sd = 1), #' # default settings #' backcalc_opts() backcalc_opts <- function(prior = "reports", prior_window = 14, rt_window = 1) { - backcalc <- list( + stop("backcalc_opts is deprecated - use process_opts instead") + backcalc <- list( prior = match.arg(prior, choices = c("reports", "none", "infections")), prior_window = prior_window, rt_window = as.integer(rt_window)