Skip to content

Commit

Permalink
implement different model types
Browse files Browse the repository at this point in the history
implements suggestions by @hsbadr

See
#213 (comment)
#213 (comment)
#213 (comment)

For now not implementing comparison to the approximate growth rate as
this seems quite a specific use case that could also be done outside the
stan model.

Also not implementing any approximate growth rate from seeding time -
instead minimum seeding time is now set to 1, so the last seeding time
is used to calculate the first growth rate.
  • Loading branch information
sbfnk committed May 17, 2024
1 parent bd2e9cc commit fb795b9
Show file tree
Hide file tree
Showing 21 changed files with 269 additions and 182 deletions.
10 changes: 6 additions & 4 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,8 @@ create_obs_model <- function(obs = obs_opts(), dates) {
#' }
create_stan_data <- function(data, seeding_time,
rt, gp, obs, horizon,
backcalc, shifted_cases) {
backcalc, shifted_cases,
process_model) {

cases <- data[(seeding_time + 1):(.N - horizon)]
complete_cases <- create_complete_cases(cases)
Expand All @@ -497,7 +498,8 @@ create_stan_data <- function(data, seeding_time,
t = length(data$date),
horizon = horizon,
burn_in = 0,
seeding_time = seeding_time
seeding_time = seeding_time,
process_model = process_model
)
# add Rt data
stan_data <- c(
Expand Down Expand Up @@ -610,7 +612,7 @@ create_initial_conditions <- function(data) {
out$rho <- array(numeric(0))
out$alpha <- array(numeric(0))
}
if (data$model_type == 1) {
if (data$obs_dist == 1) {
out$rep_phi <- array(
truncnorm::rtruncnorm(
1,
Expand All @@ -623,7 +625,7 @@ create_initial_conditions <- function(data) {
if (data$seeding_time > 1) {
out$initial_growth <- array(rnorm(1, data$prior_growth, 0.01))
}
out$log_R <- array(rnorm(
out$base_cov <- rnorm(
n = 1, mean = convert_to_logmean(data$r_mean, data$r_sd),
sd = convert_to_logsd(data$r_mean, data$r_sd) * 0.1
))
Expand Down
11 changes: 11 additions & 0 deletions R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
#'
#' @param reported_cases Deprecated; use `data` instead.
#'
#' @param process_model A character string that defines what is being
#' modelled: "infections", "growth" or "R" (default). If set to "R",
#' a generation time distribution needs to be defined via the `generation_time`
#' argument.
#'
#' @param generation_time A call to [generation_time_opts()] defining the
#' generation time distribution used. For backwards compatibility a list of
#' summary parameters can also be passed.
Expand Down Expand Up @@ -111,6 +116,7 @@
#' options(old_opts)
#' }
estimate_infections <- function(data,
process_opts = process_opts(),
generation_time = generation_time_opts(),
delays = delay_opts(),
truncation = trunc_opts(),
Expand Down Expand Up @@ -208,10 +214,15 @@ estimate_infections <- function(data,
)
reported_cases <- reported_cases[-(1:backcalc$prior_window)]

model_choices <- c("infections", "growth", "R")
model <- match.arg(model, choices = model_choices)
process_model <- which(model == model_choices) - 1

# Define stan model parameters
stan_data <- create_stan_data(
reported_cases,
seeding_time = seeding_time,
process_opts = process_opts,
rt = rt,
gp = gp,
obs = obs,
Expand Down
4 changes: 2 additions & 2 deletions R/extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates,
out$growth_rate <- extract_parameter(
"r",
samples,
reported_dates[-1]
reported_dates
)
if (data$week_effect > 1) {
out$day_of_week <- extract_parameter(
Expand All @@ -233,7 +233,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates,
date := NULL
]
}
if (data$model_type == 1) {
if (data$obs_dist == 1) {
out$reporting_overdispersion <- extract_static_parameter("rep_phi", samples)
out$reporting_overdispersion <- out$reporting_overdispersion[,
value := value.V1][,
Expand Down
90 changes: 88 additions & 2 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ trunc_opts <- function(dist = Fixed(0), tolerance = 0.001,

#' Time-Varying Reproduction Number Options
#'
#' @description `r lifecycle::badge("stable")`
#' @description `r lifecycle::badge("deprecated")`
#' Defines a list specifying the optional arguments for the time-varying
#' reproduction number. Custom settings can be supplied which override the
#' defaults.
Expand Down Expand Up @@ -359,6 +359,7 @@ rt_opts <- function(prior = list(mean = 1, sd = 1),
future = "latest",
gp_on = c("R_t-1", "R0"),
pop = 0) {
stop("rt_opts is deprecated - use process_opts instead")
rt <- list(
prior = prior,
use_rt = use_rt,
Expand All @@ -381,9 +382,93 @@ rt_opts <- function(prior = list(mean = 1, sd = 1),
return(rt)
}

#' Back Calculation Options
#' Process Model Options
#'
#' @description `r lifecycle::badge("stable")`
#' Defines a list specifying the optional arguments for the process model.
#' Custom settings can be supplied which override the defaults.
#'
#' @param model Character string, one of "infections", "growth" or "R"
#' (default), selecting what is being modelled. Partial matching is
#' supported. It is stored in the returned list as `process_model` using the
#' zero-based coding 0 = infections, 1 = growth, 2 = R.
#'
#' @param prior_mean List containing named numeric elements "mean" and "sd",
#' or `NULL`. Defaults to `list(mean = 1, sd = 1)` for `model = "R"`,
#' `list(mean = 0, sd = 1)` for `model = "growth"` and `NULL` for
#' `model = "infections"`. Exactly one of `prior_mean` and `prior_t` must be
#' non-`NULL`.
#'
#' @param prior_t Alternative to `prior_mean`; defaults to `NULL`. Stored
#' unchanged in the returned list. NOTE(review): presumably a time-varying
#' prior with one value per time point - confirm the expected format.
#'
#' @param rw Numeric step size of the random walk, defaults to 0. To specify
#' a weekly random walk set `rw = 7`. For more custom break point settings
#' consider passing in a `breakpoints` variable as outlined in the next
#' section.
#'
#' @param use_breakpoints Logical, defaults to `TRUE`. Should break points be
#' used if present as a `breakpoint` variable in the input data. Break points
#' should be defined as 1 if present and otherwise 0. By default breakpoints
#' are fit jointly with a global non-parametric effect and so represent a
#' conservative estimate of break point changes (alter this by setting
#' `gp = NULL`). Always set to `TRUE` when `rw > 0`.
#'
#' @param stationary Logical, defaults to `FALSE`. Stored unchanged in the
#' returned list. NOTE(review): presumably controls whether the process is
#' treated as stationary - confirm against the model code.
#'
#' @param pop Integer, defaults to 0. Susceptible population initially
#' present. Used to adjust Rt estimates when otherwise fixed based on the
#' proportion of the population that is susceptible. When set to 0 no
#' population adjustment is done.
#'
#' @return A list of settings defining the process model, with elements
#' `process_model`, `prior_mean`, `prior_t`, `rw`, `use_breakpoints`,
#' `future`, `stationary` and `pop`.
#' @inheritParams create_future_rt
#' @export
#' @examples
#' # default settings
#' process_opts()
#'
#' # use a custom prior
#' process_opts(prior_mean = list(mean = 2, sd = 1))
#'
#' # add a weekly random walk
#' process_opts(rw = 7)
#' @importFrom data.table fcase
process_opts <- function(model = "R",
                         prior_mean = switch(model,
                           R = list(mean = 1, sd = 1),
                           growth = list(mean = 0, sd = 1),
                           infections = NULL
                         ),
                         prior_t = NULL,
                         rw = 0,
                         use_breakpoints = TRUE,
                         future = "latest",
                         stationary = FALSE,
                         pop = 0) {
  # Resolve the model name first: the default of `prior_mean` is evaluated
  # lazily and must see the fully matched value of `model`.
  model_choices <- c("infections", "growth", "R")
  model <- match.arg(model, choices = model_choices)
  # zero-based code: 0 = infections, 1 = growth, 2 = R
  process_model <- which(model == model_choices) - 1

  # exactly one of the two prior specifications may be supplied
  if (!xor(is.null(prior_mean), is.null(prior_t))) {
    stop(
      "Exactly one of 'prior_mean' and 'prior_t' must be specified; ",
      "the other must be NULL"
    )
  }

  # Check the argument directly; `$prior` on the result list would be an
  # ambiguous partial match between `prior_mean` and `prior_t`.
  if (!is.null(prior_mean) &&
        !all(c("mean", "sd") %in% names(prior_mean))) {
    stop("prior must have both a mean and sd specified")
  }

  # a random walk is implemented via breakpoints, so force them on
  if (rw > 0) {
    use_breakpoints <- TRUE
  }

  process <- list(
    process_model = process_model,
    prior_mean = prior_mean,
    prior_t = prior_t,
    rw = rw,
    use_breakpoints = use_breakpoints,
    future = future,
    stationary = stationary,
    pop = pop
  )
  return(process)
}

#' Back Calculation Options
#'
#' @description `r lifecycle::badge("deprecated")`
#' Defines a list specifying the optional arguments for the back calculation
#' of cases. Only used if `rt = NULL`.
#'
Expand Down Expand Up @@ -417,6 +502,7 @@ rt_opts <- function(prior = list(mean = 1, sd = 1),
#' backcalc_opts()
backcalc_opts <- function(prior = c("reports", "none", "infections"),
prior_window = 14, rt_window = 1) {
stop("backcalc_opts is deprecated - use process_opts instead")
backcalc <- list(
prior = arg_match(prior),
prior_window = prior_window,
Expand Down
1 change: 0 additions & 1 deletion inst/stan/data/backcalc.stan
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
int backcalc_prior; // Prior type to use for backcalculation
int rt_half_window; // Half the moving average window used when calculating Rt
7 changes: 7 additions & 0 deletions inst/stan/data/covariates.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// Data for the process model covariates. Uses `t` and `seeding_time`, which
// are not declared here and must be declared before this file is included.
int process_model; // 0 = infections; 1 = growth; 2 = rt
int bp_n; // no of breakpoints (0 = no breakpoints)
array[t - seeding_time] int breakpoints; // when do breakpoints occur
int cov_mean_const; // 0 = not const mean; 1 = const mean
// the arrays below have length 1 if the mean is constant and 0 otherwise
array[cov_mean_const] real<lower = 0> cov_mean_mean; // const covariate mean
array[cov_mean_const] real<lower = 0> cov_mean_sd; // const covariate sd
// empty when the mean is constant, otherwise one value per time point
vector<lower = 0>[cov_mean_const ? 0 : t] cov_t; // time-varying covariate mean
2 changes: 1 addition & 1 deletion inst/stan/data/observation_model.stan
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
array[t - seeding_time] int day_of_week; // day of the week indicator (1 - 7)
int model_type; // type of model: 0 = poisson otherwise negative binomial
int obs_dist; // type of model: 0 = poisson otherwise negative binomial
real phi_mean; // Mean and sd of the normal prior for the
real phi_sd; // reporting process
int week_effect; // length of week effect
Expand Down
2 changes: 1 addition & 1 deletion inst/stan/data/observations.stan
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
int t; // unobserved time
int lt; // timepoints in the likelihood
int seeding_time; // time period used for seeding and not observed
int<lower = 1> seeding_time; // time period used for seeding and not observed
int horizon; // forecast horizon
int future_time; // time in future for Rt
array[lt] int<lower = 0> cases; // observed cases
Expand Down
4 changes: 0 additions & 4 deletions inst/stan/data/rt.stan
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
int estimate_r; // should the reproduction no be estimated (1 = yes)
real prior_infections; // prior for initial infections
real prior_growth; // prior on initial growth rate
real <lower = 0> r_mean; // prior mean of reproduction number
real <lower = 0> r_sd; // prior standard deviation of reproduction number
int bp_n; // no of breakpoints (0 = no breakpoints)
array[t - seeding_time] int breakpoints; // when do breakpoints occur
int future_fixed; // is underlying future Rt assumed to be fixed
int fixed_from; // Reference date for when Rt estimation should be fixed
int pop; // Initial susceptible population
Expand Down
4 changes: 2 additions & 2 deletions inst/stan/data/simulation_observation_model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
array[n, week_effect] real<lower = 0> day_of_week_simplex;
int obs_scale;
array[n, obs_scale] real<lower = 0, upper = 1> frac_obs;
int model_type;
array[n, model_type] real<lower = 0> rep_phi; // overdispersion of the reporting process
int obs_dist;
array[n, obs_dist] real<lower = 0> rep_phi; // overdispersion of the reporting process
int<lower = 0> trunc_id; // id of truncation
Loading

0 comments on commit fb795b9

Please sign in to comment.